path: root/mlir/python
author    lorenzo chelini <l.chelini@icloud.com>  2022-09-19 12:11:04 +0200
committer Lorenzo Chelini <l.chelini@icloud.com>  2022-09-19 12:11:54 +0200
commit  f381768a8da6bd6bde8bdff34f080bf12bf20064 (patch)
tree    8f1c30aed7bb315433296acc8050f848413461ea /mlir/python
parent  393cc6a354c625abadff1d45c53b847ba36b45f2 (diff)
[MLIR][Linalg] introduce batch-reduce GEMM
The batch-reduce GEMM kernel essentially multiplies a sequence of input tensor blocks (which form a batch), and the partial multiplication results are reduced into a single output tensor block. See https://ieeexplore.ieee.org/document/9139809 for more details.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D134163
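For intuition, the op computes C[m, n] += sum over b of A[b, m, k] * B[b, k, n], i.e. one GEMM per batch entry with every partial product accumulated into the same 2D output block. A minimal NumPy sketch of this reference semantics (the function name, shapes, and dtype handling below are illustrative assumptions, not part of the patch):

import numpy as np

def batch_reduce_matmul_ref(A, B, C):
  # A: (batch, M, K), B: (batch, K, N), C: (M, N) accumulator.
  # Each batch entry contributes one partial GEMM; all partials are
  # reduced (summed) into the single 2D output block C.
  for b in range(A.shape[0]):
    C += A[b] @ B[b]
  return C  # equivalent to C + np.einsum('bmk,bkn->mn', A, B)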
Diffstat (limited to 'mlir/python')
-rw-r--r--  mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py  14
1 file changed, 14 insertions(+), 0 deletions(-)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 983842c..b9b292d 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -150,6 +150,20 @@ def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
+@linalg_structured_op
+def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
+                        B=TensorDef(T2, Batch, S.K, S.N),
+                        C=TensorDef(U, S.M, S.N, output=True)):
+  """Performs a batch-reduce matrix multiplication of two 3D inputs.
+  The partial multiplication results are reduced into a 2D output.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.b, D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+      U, B[D.b, D.k, D.n])
@linalg_structured_op
def matvec(A=TensorDef(T1, S.M, S.N),
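Once this OpDSL definition is regenerated into the Linalg named-op definitions, the new op should be emittable from the MLIR Python bindings like the other OpDSL-defined structured ops. A hypothetical usage sketch, following the usual tensor builder boilerplate (shapes, the brgemm function name, and element types below are illustrative assumptions, not part of this change):

from mlir.dialects import func, linalg
from mlir.ir import (Context, F32Type, InsertionPoint, Location, Module,
                     RankedTensorType)

with Context(), Location.unknown():
  module = Module.create()
  f32 = F32Type.get()
  with InsertionPoint(module.body):

    @func.FuncOp.from_py_func(
        RankedTensorType.get((4, 16, 8), f32),   # A: batch x M x K
        RankedTensorType.get((4, 8, 32), f32),   # B: batch x K x N
        RankedTensorType.get((16, 32), f32))     # C: M x N accumulator
    def brgemm(A, B, C):
      # Emits the named op generated from the batch_reduce_matmul
      # definition above; the result is the reduced 2D tensor.
      return linalg.batch_reduce_matmul(A, B, outs=[C])

  print(module)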