diff options
author | Yinying Li <yinyingli@google.com> | 2024-05-02 12:28:34 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-02 12:28:34 -0400 |
commit | e71eacc5b19785bc46ce9c3d8541a0c83c65660e (patch) | |
tree | a44b540654368d137439cd1bf958610d1a373fa2 | |
parent | 0708500ce0149c501e802b7ab6581770cc7a5334 (diff) | |
download | llvm-e71eacc5b19785bc46ce9c3d8541a0c83c65660e.zip llvm-e71eacc5b19785bc46ce9c3d8541a0c83c65660e.tar.gz llvm-e71eacc5b19785bc46ce9c3d8541a0c83c65660e.tar.bz2 |
[mlir][sparse] Support explicit/implicit value for complex type (#90771)
6 files changed, 42 insertions, 9 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt index dd6f103..6f59b69 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt @@ -45,6 +45,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect LINK_LIBS PUBLIC MLIRArithDialect + MLIRComplexDialect MLIRDialect MLIRDialectUtils MLIRIR diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 028a69d..de3d300 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -663,6 +664,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { explicitVal = result; } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { explicitVal = result; + } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { + explicitVal = result; } else { parser.emitError(parser.getNameLoc(), "expected a numeric value for explicitVal"); @@ -678,6 +681,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { implicitVal = result; } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { implicitVal = result; + } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { + implicitVal = result; } else { parser.emitError(parser.getNameLoc(), "expected a numeric value for implicitVal"); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h index cf3c35f..d0ef8a6 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h @@ -401,9 +401,12 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, // Generates a constant from a validated value carrying attribute. inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) { - if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { - Type tp = cast<TypedAttr>(arrayAttr[0]).getType(); - return builder.create<complex::ConstantOp>(loc, tp, arrayAttr); + if (auto complexAttr = dyn_cast<complex::NumberAttr>(attr)) { + Type tp = cast<ComplexType>(complexAttr.getType()).getElementType(); + return builder.create<complex::ConstantOp>( + loc, complexAttr.getType(), + builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()), + FloatAttr::get(tp, complexAttr.getImag())})); } return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr)); } diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir index 7eeda9a..7fb1c76 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir @@ -80,6 +80,21 @@ func.func private @sparse_csr(tensor<?x?xi64, #CSR_OnlyOnes>) // ----- +#CSR_OnlyOnes = #sparse_tensor.encoding<{ + map = (d0, d1) -> (d0 : dense, d1 : compressed), + posWidth = 64, + crdWidth = 64, + explicitVal = #complex.number<:f32 1.0, 0.0>, + implicitVal = #complex.number<:f32 0.0, 0.0> +}> + +// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = #complex.number<:f32 1.000000e+00, 0.000000e+00> : complex<f32>, implicitVal = #complex.number<:f32 0.000000e+00, 0.000000e+00> : complex<f32> }> +// CHECK-LABEL: func private @sparse_csr( +// CHECK-SAME: tensor<?x?xcomplex<f32>, #[[$CSR_OnlyOnes]]>) +func.func private @sparse_csr(tensor<?x?xcomplex<f32>, #CSR_OnlyOnes>) + +// ----- + #BCSR = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed), }> diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir index 82f3147..be21725 100755 --- a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir @@ -2,9 +2,9 @@ // RUN: --sparsification-and-bufferization | FileCheck %s #CSR_ones_complex = #sparse_tensor.encoding<{ - map = (d0, d1) -> (d0 : dense, d1 : compressed) -// explicitVal = (1.0, 0.0) : complex<f32>, -// implicitVal = (0.0, 0.0) : complex<f32> + map = (d0, d1) -> (d0 : dense, d1 : compressed), + explicitVal = #complex.number<:f32 1.0, 0.0>, + implicitVal = #complex.number<:f32 0.0, 0.0> }> #CSR_ones_fp = #sparse_tensor.encoding<{ @@ -20,9 +20,17 @@ }> // CHECK-LABEL: func.func @matmul_complex -// -// TODO: make this work -// +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[X:.*]] = memref.load +// CHECK: scf.for +// CHECK: %[[I:.*]] = memref.load +// CHECK: %[[Y:.*]] = memref.load +// CHECK: %[[M:.*]] = complex.add %[[Y]], %[[X]] : complex<f32> +// CHECK: memref.store %[[M]] +// CHECK: } +// CHECK: } +// CHECK: } func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>, %b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>, %c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index acd2d3a..13c246a 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3066,6 +3066,7 @@ cc_library( ":ArithDialect", ":BufferizationInterfaces", ":BytecodeOpInterface", + ":ComplexDialect", ":DialectUtils", ":IR", ":InferTypeOpInterface", |