aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Interfaces/SubsetOpInterface.cpp
blob: d0bdadf500f6f6ccd48441723bfc2fc4b26e9b7d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
//===- SubsetOpInterface.cpp - Tensor Subsets -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"

using namespace mlir;

// Default destination operand for subset insertion ops: the single DPS init
// operand. Ops that are not destination-style, or that have zero or several
// inits, must provide their own implementation (enforced by asserts only).
OpOperand &detail::defaultGetDestinationOperand(Operation *op) {
  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op);
  assert(dpsOp && "getDestination must be implemented for non-DPS ops");
  // Only a lone init is unambiguous enough to serve as "the" destination.
  assert(
      dpsOp.getNumDpsInits() == 1 &&
      "getDestination must be implemented for ops with 0 or more than 1 init");
  OpOperand *soleInit = dpsOp.getDpsInitOperand(0);
  return *soleInit;
}

// Default updated destination for subset insertion ops: the DPS result tied to
// the op's destination operand. Requires `op` to be destination-style (checked
// by assert) and to implement SubsetInsertionOpInterface (checked by cast).
OpResult detail::defaultGetUpdatedDestination(Operation *op) {
  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op);
  assert(dpsOp && "getUpdatedDestination must be implemented for non-DPS ops");
  OpOperand &destination =
      cast<SubsetInsertionOpInterface>(op).getDestinationOperand();
  return dpsOp.getTiedOpResult(&destination);
}

// Default equivalence check between an insertion op's subset and `candidate`.
// Returns false unless `candidate` is produced by a subset extraction op;
// otherwise defers to operatesOnEquivalentSubset on the insertion op, passing
// `equivalenceFn` through unchanged.
bool detail::defaultIsEquivalentSubset(
    Operation *op, Value candidate,
    function_ref<bool(Value, Value)> equivalenceFn) {
  assert(isa<SubsetInsertionOpInterface>(op) &&
         "expected SubsetInsertionOpInterface");
  Operation *producer = candidate.getDefiningOp();
  // Block arguments and non-extraction producers cannot be matched here.
  if (!producer || !isa<SubsetExtractionOpInterface>(producer))
    return false;
  auto insertionSubset = cast<SubsetOpInterface>(op);
  return insertionSubset.operatesOnEquivalentSubset(
      dyn_cast<SubsetOpInterface>(producer), equivalenceFn);
}

// Default implementation that decides whether `op` and `candidate` access the
// same subset. It relies on both ops describing their accesses as
// hyperrectangular slices; `op` is required (by assert) to do so.
//
// Returns true only if:
//   * the candidate also exposes a hyperrectangular slice,
//   * `equivalenceFn` reports the two tensor containers as equivalent, and
//   * ValueBoundsConstraintSet proves the two slices equivalent.
// Any failure along the way yields false.
bool detail::defaultOperatesOnEquivalentSubset(
    Operation *op, SubsetOpInterface candidate,
    function_ref<bool(Value, Value)> equivalenceFn) {
  auto thisSubset = cast<SubsetOpInterface>(op);
  FailureOr<HyperrectangularSlice> thisSlice =
      thisSubset.getAccessedHyperrectangularSlice();
  assert(succeeded(thisSlice) &&
         "operatesOnEquivalentSubset must be implemented if "
         "getAccessedHyperrectangularSlice is not implemented");
  FailureOr<HyperrectangularSlice> candidateSlice =
      candidate.getAccessedHyperrectangularSlice();
  // Short-circuit keeps equivalenceFn from running when the candidate slice
  // is unavailable.
  if (failed(candidateSlice) ||
      !equivalenceFn(thisSubset.getTensorContainer(),
                     candidate.getTensorContainer()))
    return false;
  FailureOr<bool> sameSlice = ValueBoundsConstraintSet::areEquivalentSlices(
      op->getContext(), *thisSlice, *candidateSlice);
  // An inconclusive analysis is treated as "not equivalent".
  return succeeded(sameSlice) ? *sameSlice : false;
}

// Default implementation that decides whether `op` and `candidate` access
// disjoint subsets. It relies on both ops describing their accesses as
// hyperrectangular slices; `op` is required (by assert) to do so.
//
// Returns true only if:
//   * the candidate also exposes a hyperrectangular slice,
//   * `equivalenceFn` reports the two tensor containers as equivalent (slices
//     of unrelated containers cannot be compared), and
//   * ValueBoundsConstraintSet proves the two slices do not overlap.
// Any failure along the way yields false, i.e. "not known to be disjoint".
bool detail::defaultOperatesOnDisjointSubset(
    Operation *op, SubsetOpInterface candidate,
    function_ref<bool(Value, Value)> equivalenceFn) {
  auto subsetOp = cast<SubsetOpInterface>(op);
  FailureOr<HyperrectangularSlice> slice =
      subsetOp.getAccessedHyperrectangularSlice();
  // Name the interface method implementers must override (not this internal
  // default helper), consistent with defaultOperatesOnEquivalentSubset.
  assert(succeeded(slice) &&
         "operatesOnDisjointSubset must be implemented if "
         "getAccessedHyperrectangularSlice is not implemented");
  FailureOr<HyperrectangularSlice> otherSlice =
      candidate.getAccessedHyperrectangularSlice();
  if (failed(otherSlice))
    return false;
  if (!equivalenceFn(subsetOp.getTensorContainer(),
                     candidate.getTensorContainer()))
    return false;
  FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
      op->getContext(), *slice, *otherSlice);
  // Disjointness must be proven: an inconclusive analysis counts as "maybe
  // overlapping" and therefore returns false.
  return succeeded(overlapping) && !*overlapping;
}

// Returns the value whose subset `op` accesses: the destination operand of an
// insertion op, otherwise the source operand of an extraction op. The
// insertion interface is checked first; a non-insertion op that is not an
// extraction op triggers the `cast` assertion.
Value detail::getTensorContainer(Operation *op) {
  OpOperand *containerOperand = nullptr;
  if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op)) {
    containerOperand = &insertionOp.getDestinationOperand();
  } else {
    auto extractionOp = cast<::mlir::SubsetExtractionOpInterface>(op);
    containerOperand = &extractionOp.getSourceOperand();
  }
  return containerOperand->get();
}

// Verifies that a SubsetOpInterface op implements exactly one of
// SubsetExtractionOpInterface and SubsetInsertionOpInterface (neither and both
// are rejected).
LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
  Operation *rawOp = op.getOperation();
  bool isExtraction = isa<SubsetExtractionOpInterface>(rawOp);
  bool isInsertion = isa<SubsetInsertionOpInterface>(rawOp);
  if (isExtraction != isInsertion)
    return success();
  return op->emitOpError(
      "SubsetOpInterface ops must implement either "
      "SubsetExtractionOpInterface or SubsetInsertionOpInterface");
}

// Verifies that a SubsetExtractionOpInterface op produces exactly one result;
// any other result count emits an op error.
LogicalResult
detail::verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op) {
  if (op->getNumResults() == 1)
    return success();
  return op->emitOpError(
      "SubsetExtractionOpInterface ops must have one result");
}