diff options
author | Nicolas Vasilache <nico.vasilache@amd.com> | 2025-06-20 19:07:52 +0200 |
---|---|---|
committer | Nicolas Vasilache <nico.vasilache@amd.com> | 2025-06-21 12:25:35 +0200 |
commit | e7ab6811047de13847a624287f32af932b8c6427 (patch) | |
tree | cf8eaf767b3e67caa98f1891078cee59c39d2a61 | |
parent | d6a486c221c1a2d18e88ca39279bcf1675fe7723 (diff) | |
download | llvm-users/nico/revisit-pad.zip llvm-users/nico/revisit-pad.tar.gz llvm-users/nico/revisit-pad.tar.bz2 |
[mlir][transform][tensor] Add GetDimOpusers/nico/revisit-pad
3 files changed, 69 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td index 9f6387d..45264b9 100644 --- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td +++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td @@ -125,6 +125,25 @@ def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect, "(`aggressive` $aggressive^)? attr-dict"; } +def GetDimOp : TransformDialectOp<"tensor.get_dim", + [DeclareOpInterfaceMethods<TransformOpInterface>, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + TransformEachOpTrait]> { + let summary = "Get a handle to the result of a newly created DimOp"; + let description = [{ + This transform always succeeds, it is the user's responsibility that the op + arguments makes sense and verifies. + }]; + + let arguments = (ins TransformValueHandleTypeInterface:$target, + I64Attr:$rank); + let results = (outs TransformValueHandleTypeInterface:$result); + let assemblyFormat = + "$target `[` $rank `]` " + "attr-dict `:` functional-type(operands, results)"; + let hasVerifier = 1; +} + def Transform_TensorPadOp : Transform_ConcreteOpType<"tensor.pad">; def MakeLoopIndependentOp diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index 723731b..0727ace 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" @@ -144,6 +145,37 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns( } //===----------------------------------------------------------------------===// +// GetDimOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::GetDimOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + SmallVector<Value> dims; + for (Value v : state.getPayloadValues(getTarget())) { + Location loc = v.getLoc(); + Value cst = rewriter.create<arith::ConstantIndexOp>(loc, getRank()); + Value dim = rewriter.create<tensor::DimOp>(loc, v, cst); + dims.push_back(dim); + } + results.setValues(cast<OpResult>(getResult()), dims); + return DiagnosedSilenceableFailure::success(); +} + +void transform::GetDimOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + onlyReadsPayload(effects); +} + +LogicalResult transform::GetDimOp::verify() { + // TODO: could verify rank. + return success(); +} + +//===----------------------------------------------------------------------===// // TypeConversionCastTensorShapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/transform-op-simple.mlir b/mlir/test/Dialect/Tensor/transform-op-simple.mlir new file mode 100644 index 0000000..e7b827c --- /dev/null +++ b/mlir/test/Dialect/Tensor/transform-op-simple.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @create_dim( +// CHECK-NEXT: tensor.dim +func.func @create_dim() -> tensor<8x16xf32> { + %s = arith.constant 1.0 : f32 + %t = tensor.splat %s : tensor<8x16xf32> + return %t: tensor<8x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %splat = transform.structured.match ops{["tensor.splat"]} in %module : (!transform.any_op) -> !transform.any_op + %t = transform.get_operand %splat[0] : (!transform.any_op) -> !transform.any_value + %_ = transform.tensor.get_dim %t[0] : (!transform.any_value) -> !transform.any_value + transform.yield + } +} |