diff options
author | lorenzo chelini <l.chelini@icloud.com> | 2022-09-19 12:11:04 +0200 |
---|---|---|
committer | Lorenzo Chelini <l.chelini@icloud.com> | 2022-09-19 12:11:54 +0200 |
commit | f381768a8da6bd6bde8bdff34f080bf12bf20064 (patch) | |
tree | 8f1c30aed7bb315433296acc8050f848413461ea /mlir/python | |
parent | 393cc6a354c625abadff1d45c53b847ba36b45f2 (diff) | |
download | llvm-f381768a8da6bd6bde8bdff34f080bf12bf20064.zip llvm-f381768a8da6bd6bde8bdff34f080bf12bf20064.tar.gz llvm-f381768a8da6bd6bde8bdff34f080bf12bf20064.tar.bz2 |
[MLIR][Linalg] introduce batch-reduce GEMM
The batch-reduce GEMM kernel essentially multiplies a sequence of input tensor
blocks (which form a batch) and the partial multiplication results are reduced
into a single output tensor block.
See: https://ieeexplore.ieee.org/document/9139809 for more details.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D134163
Diffstat (limited to 'mlir/python')
-rw-r--r-- | mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 14 |
1 file changed, 14 insertions, 0 deletions
@linalg_structured_op
def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
                        B=TensorDef(T2, Batch, S.K, S.N),
                        C=TensorDef(U, S.M, S.N, output=True)):
  """Performs a batch-reduce matrix multiplication of two 3D inputs.

  The partial multiplication results are reduced into a 2D output: for every
  batch index b, the product A[b] @ B[b] is accumulated into the single
  output block C.

  Numeric casting is performed on the operands to the inner multiply,
  promoting them to the same data type as the accumulator/output.
  """
  # Iteration space: b (reduction over the batch), m/n (parallel output
  # dims), k (contraction dim).
  domain(D.b, D.m, D.n, D.k)
  implements(ContractionOpInterface)
  # Cast each operand to the accumulator type U *before* multiplying, matching
  # the casting convention of the other matmul variants in this file (the
  # collapsed original applied the cast to A's raw product, with unbalanced
  # parentheses).
  C[D.m, D.n] += TypeFn.cast_signed(
      U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(U, B[D.b, D.k, D.n])