aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Vasilache <nico.vasilache@amd.com>2025-06-20 19:07:52 +0200
committerNicolas Vasilache <nico.vasilache@amd.com>2025-06-21 12:25:35 +0200
commite7ab6811047de13847a624287f32af932b8c6427 (patch)
treecf8eaf767b3e67caa98f1891078cee59c39d2a61
parentd6a486c221c1a2d18e88ca39279bcf1675fe7723 (diff)
downloadllvm-users/nico/revisit-pad.zip
llvm-users/nico/revisit-pad.tar.gz
llvm-users/nico/revisit-pad.tar.bz2
[mlir][transform][tensor] Add GetDimOpusers/nico/revisit-pad
-rw-r--r--mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td19
-rw-r--r--mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp32
-rw-r--r--mlir/test/Dialect/Tensor/transform-op-simple.mlir18
3 files changed, 69 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 9f6387d..45264b9 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -125,6 +125,25 @@ def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
"(`aggressive` $aggressive^)? attr-dict";
}
+def GetDimOp : TransformDialectOp<"tensor.get_dim",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformEachOpTrait]> {
+ let summary = "Get a handle to the result of a newly created DimOp";
+ let description = [{
+ This transform always succeeds, it is the user's responsibility that the op
+ arguments makes sense and verifies.
+ }];
+
+ let arguments = (ins TransformValueHandleTypeInterface:$target,
+ I64Attr:$rank);
+ let results = (outs TransformValueHandleTypeInterface:$result);
+ let assemblyFormat =
+ "$target `[` $rank `]` "
+ "attr-dict `:` functional-type(operands, results)";
+ let hasVerifier = 1;
+}
+
def Transform_TensorPadOp : Transform_ConcreteOpType<"tensor.pad">;
def MakeLoopIndependentOp
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 723731b..0727ace 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
@@ -144,6 +145,37 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
}
//===----------------------------------------------------------------------===//
+// GetDimOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::GetDimOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Value> dims;
+ for (Value v : state.getPayloadValues(getTarget())) {
+ Location loc = v.getLoc();
+ Value cst = rewriter.create<arith::ConstantIndexOp>(loc, getRank());
+ Value dim = rewriter.create<tensor::DimOp>(loc, v, cst);
+ dims.push_back(dim);
+ }
+ results.setValues(cast<OpResult>(getResult()), dims);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::GetDimOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ onlyReadsPayload(effects);
+}
+
+LogicalResult transform::GetDimOp::verify() {
+ // TODO: could verify rank.
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// TypeConversionCastTensorShapeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/transform-op-simple.mlir b/mlir/test/Dialect/Tensor/transform-op-simple.mlir
new file mode 100644
index 0000000..e7b827c
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/transform-op-simple.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @create_dim(
+// CHECK-NEXT: tensor.dim
+func.func @create_dim() -> tensor<8x16xf32> {
+ %s = arith.constant 1.0 : f32
+ %t = tensor.splat %s : tensor<8x16xf32>
+ return %t: tensor<8x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %splat = transform.structured.match ops{["tensor.splat"]} in %module : (!transform.any_op) -> !transform.any_op
+ %t = transform.get_operand %splat[0] : (!transform.any_op) -> !transform.any_value
+ %_ = transform.tensor.get_dim %t[0] : (!transform.any_value) -> !transform.any_value
+ transform.yield
+ }
+}