aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2024-07-24 11:31:02 +0200
committerMatthias Springer <mspringer@nvidia.com>2024-07-24 11:31:02 +0200
commitf3d460a0477623736ef14a91cac8eabe378c0ae7 (patch)
treeea96709e2778b173385c9860071a5f7ebdbf9b91
parentba8126b6fef79bd344a247f6291aaec7b67bdff0 (diff)
downloadllvm-users/matthias-springer/value_get_owning_op.zip
llvm-users/matthias-springer/value_get_owning_op.tar.gz
llvm-users/matthias-springer/value_get_owning_op.tar.bz2
[mlir] Add `Value::getOwningOp`users/matthias-springer/value_get_owning_op
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h7
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td2
-rw-r--r--mlir/include/mlir/IR/Value.h18
-rw-r--r--mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp6
-rw-r--r--mlir/lib/Analysis/SliceAnalysis.cpp7
-rw-r--r--mlir/lib/CAPI/IR/IR.cpp13
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp26
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp2
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp4
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp16
-rw-r--r--mlir/lib/IR/Value.cpp6
-rw-r--r--mlir/lib/Interfaces/ValueBoundsOpInterface.cpp17
-rw-r--r--mlir/lib/Transforms/Mem2Reg.cpp4
13 files changed, 50 insertions, 78 deletions
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 2fda091..532ba55 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -60,7 +60,8 @@ struct AliasingValue {
bool isDefinite;
};
-template <typename T> class AliasList {
+template <typename T>
+class AliasList {
public:
/// Create an empty list of aliases.
AliasList() = default;
@@ -663,10 +664,6 @@ BaseMemRefType
getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
Attribute memorySpace = nullptr);
-/// Return the owner of the given value. In case of a BlockArgument that is the
-/// owner of the block. In case of an OpResult that is the defining op.
-Operation *getOwnerOfValue(Value value);
-
/// Assuming that the given region is repetitive, find the next enclosing
/// repetitive region.
Region *getNextEnclosingRepetitiveRegion(Region *region,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 80cd13d..f84d366 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
"::llvm::SmallVector<::mlir::Value> &":$invocationStack),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(getOwnerOfValue(value) == $_op.getOperation() &&
+ assert(value.getOwningOp() == $_op.getOperation() &&
"expected that value belongs to this op");
assert(invocationStack.back() == value &&
"inconsistant invocation stack");
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index a7344c6..7e4c2e0 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -98,26 +98,23 @@ public:
constexpr Value(detail::ValueImpl *impl = nullptr) : impl(impl) {}
template <typename U>
- [[deprecated("Use mlir::isa<U>() instead")]]
- bool isa() const {
+ [[deprecated("Use mlir::isa<U>() instead")]] bool isa() const {
return llvm::isa<U>(*this);
}
template <typename U>
- [[deprecated("Use mlir::dyn_cast<U>() instead")]]
- U dyn_cast() const {
+ [[deprecated("Use mlir::dyn_cast<U>() instead")]] U dyn_cast() const {
return llvm::dyn_cast<U>(*this);
}
template <typename U>
- [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
- U dyn_cast_or_null() const {
+ [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]] U
+ dyn_cast_or_null() const {
return llvm::dyn_cast_or_null<U>(*this);
}
template <typename U>
- [[deprecated("Use mlir::cast<U>() instead")]]
- U cast() const {
+ [[deprecated("Use mlir::cast<U>() instead")]] U cast() const {
return llvm::cast<U>(*this);
}
@@ -154,6 +151,11 @@ public:
Location getLoc() const;
void setLoc(Location loc);
+ /// Return the owning operation of the this value. In case of a
+ /// BlockArgument, it is the owner of the block. In case of an OpResult, it
+ /// is the defining op.
+ Operation *getOwningOp() const;
+
/// Return the Region in which this Value is defined.
Region *getParentRegion();
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 6cece46..6d70e82 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -225,11 +225,7 @@ getAllocEffectFor(Value value,
std::optional<MemoryEffects::EffectInstance> &effect,
Operation *&allocScopeOp) {
// Try to get a memory effect interface for the parent operation.
- Operation *op;
- if (BlockArgument arg = dyn_cast<BlockArgument>(value))
- op = arg.getOwner()->getParentOp();
- else
- op = cast<OpResult>(value).getOwner();
+ Operation *op = value.getOwningOp();
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
if (!interface)
return failure();
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 2b1cf41..4505334 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -132,12 +132,7 @@ void mlir::getBackwardSlice(Operation *op,
void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options) {
- if (Operation *definingOp = root.getDefiningOp()) {
- getBackwardSlice(definingOp, backwardSlice, options);
- return;
- }
- Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
- getBackwardSlice(bbAargOwner, backwardSlice, options);
+ getBackwardSlice(root.getOwningOp(), backwardSlice, options);
}
SetVector<Operation *>
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 5eb531b..a49be11 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -167,16 +167,11 @@ static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
MlirAsmState mlirAsmStateCreateForValue(MlirValue value,
MlirOpPrintingFlags flags) {
- Operation *op;
mlir::Value val = unwrap(value);
- if (auto result = llvm::dyn_cast<OpResult>(val)) {
- op = result.getOwner();
- } else {
- op = llvm::cast<BlockArgument>(val).getOwner()->getParentOp();
- if (!op) {
- emitError(val.getLoc()) << "<<UNKNOWN SSA VALUE>>";
- return {nullptr};
- }
+ Operation *op = val.getOwningOp();
+ if (!op) {
+ emitError(val.getLoc()) << "<<UNKNOWN SSA VALUE>>";
+ return {nullptr};
}
op = findParent(op, unwrap(flags)->shouldUseLocalScope());
return wrap(new AsmState(op, *unwrap(flags)));
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d51d63f..b1497b9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -134,12 +134,6 @@ Region *bufferization::getParallelRegion(Region *region,
return nullptr;
}
-Operation *bufferization::getOwnerOfValue(Value value) {
- if (auto opResult = llvm::dyn_cast<OpResult>(value))
- return opResult.getDefiningOp();
- return llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
-}
-
/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
@@ -153,8 +147,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
tensor = b.create<ToTensorOp>(loc, shapedValue);
} else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
- return getOwnerOfValue(shapedValue)
- ->emitError("copying of unranked tensors is not implemented");
+ return shapedValue.getOwningOp()->emitError(
+ "copying of unranked tensors is not implemented");
} else {
llvm_unreachable("expected RankedTensorType or MemRefType");
}
@@ -355,7 +349,7 @@ BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
- return dynCastBufferizableOp(getOwnerOfValue(value));
+ return dynCastBufferizableOp(value.getOwningOp());
}
void BufferizationOptions::setFunctionBoundaryTypeConversion(
@@ -388,9 +382,9 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
/// Determine which OpOperand* will alias with `value` if the op is bufferized
/// in place. Return all tensor OpOperand* if the op is not bufferizable.
AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
- if (Operation *op = getOwnerOfValue(value))
- if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
- return bufferizableOp.getAliasingOpOperands(value, *this);
+ if (auto bufferizableOp =
+ getOptions().dynCastBufferizableOp(value.getOwningOp()))
+ return bufferizableOp.getAliasingOpOperands(value, *this);
// The op is not bufferizable.
return detail::unknownGetAliasingOpOperands(value);
@@ -677,7 +671,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
llvm::make_scope_exit([&]() { invocationStack.pop_back(); });
// Try querying BufferizableOpInterface.
- Operation *op = getOwnerOfValue(value);
+ Operation *op = value.getOwningOp();
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (bufferizableOp)
return bufferizableOp.getBufferType(value, options, invocationStack);
@@ -880,7 +874,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
// * conflictingWrite = %1
//
auto isMemoryWriteInsideOp = [&](Value v) {
- Operation *op = getOwnerOfValue(v);
+ Operation *op = v.getOwningOp();
if (!opResult.getDefiningOp()->isAncestor(op))
return false;
return state.bufferizesToMemoryWrite(v);
@@ -901,7 +895,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
// getAliasingValues.
AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
Value value, const AnalysisState &state) {
- Operation *op = getOwnerOfValue(value);
+ Operation *op = value.getOwningOp();
SmallVector<AliasingOpOperand> result;
for (OpOperand &opOperand : op->getOpOperands()) {
if (!llvm::isa<TensorType>(opOperand.get().getType()))
@@ -924,7 +918,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
return bufferization::getMemRefType(value, options);
// Value is an OpResult.
- Operation *op = getOwnerOfValue(value);
+ Operation *op = value.getOwningOp();
auto opResult = llvm::cast<OpResult>(value);
AnalysisState state(options);
AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 975bfb4..b2a3cd7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -227,7 +227,7 @@ bool OneShotAnalysisState::isWritable(Value value) const {
// TODO: Out-of-place bufferized value could be considered writable.
// Query BufferizableOpInterface to see if the BlockArgument is writable.
if (auto bufferizableOp =
- getOptions().dynCastBufferizableOp(getOwnerOfValue(value)))
+ getOptions().dynCastBufferizableOp(value.getOwningOp()))
return bufferizableOp.isWritable(value, *this);
// Not a bufferizable op: The conservative answer is "not writable".
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index cf40443..d00436a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -697,7 +697,7 @@ struct ForOpInterface
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto forOp = cast<scf::ForOp>(op);
- assert(getOwnerOfValue(value) == op && "invalid value");
+ assert(value.getOwningOp() == op && "invalid value");
assert(isa<TensorType>(value.getType()) && "expected tensor type");
if (auto opResult = dyn_cast<OpResult>(value)) {
@@ -1020,7 +1020,7 @@ struct WhileOpInterface
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto whileOp = cast<scf::WhileOp>(op);
- assert(getOwnerOfValue(value) == op && "invalid value");
+ assert(value.getOwningOp() == op && "invalid value");
assert(isa<TensorType>(value.getType()) && "expected tensor type");
// Case 1: Block argument of the "before" region.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e5b1291..f8b661f 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
/// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
@@ -3925,15 +3926,10 @@ static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
void Value::printAsOperand(raw_ostream &os,
const OpPrintingFlags &flags) const {
- Operation *op;
- if (auto result = llvm::dyn_cast<OpResult>(*this)) {
- op = result.getOwner();
- } else {
- op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();
- if (!op) {
- os << "<<UNKNOWN SSA VALUE>>";
- return;
- }
+ Operation *op = getOwningOp();
+ if (!op) {
+ os << "<<UNKNOWN SSA VALUE>>";
+ return;
}
op = findParent(op, flags.shouldUseLocalScope());
AsmState state(op, flags);
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 1787653..bfb4b62 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -37,6 +37,12 @@ void Value::setLoc(Location loc) {
return llvm::cast<BlockArgument>(*this).setLoc(loc);
}
+Operation *Value::getOwningOp() const {
+ if (auto blockArg = llvm::dyn_cast<BlockArgument>(*this))
+ return blockArg.getOwner()->getParentOp();
+ return llvm::cast<OpResult>(*this).getOwner();
+}
+
/// Return the Region in which this Value is defined.
Region *Value::getParentRegion() {
if (auto *op = getDefiningOp())
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 6420c19..3ab6227d 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -25,12 +25,6 @@ namespace mlir {
#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
} // namespace mlir
-static Operation *getOwnerOfValue(Value value) {
- if (auto bbArg = dyn_cast<BlockArgument>(value))
- return bbArg.getOwner()->getParentOp();
- return value.getDefiningOp();
-}
-
HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides)
@@ -272,7 +266,7 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
<< " for: " << value
<< " (dim: " << dim.value_or(kIndexValue)
- << ", owner: " << getOwnerOfValue(value)->getName()
+ << ", owner: " << value.getOwningOp()->getName()
<< ")\n");
positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
// Update reverse mapping.
@@ -338,7 +332,7 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
#endif // NDEBUG
LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
<< " (dim: " << dim.value_or(kIndexValue)
- << ", owner: " << getOwnerOfValue(value)->getName()
+ << ", owner: " << value.getOwningOp()->getName()
<< ")\n");
auto it =
valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
@@ -390,11 +384,10 @@ void ValueBoundsConstraintSet::processWorklist() {
// Query `ValueBoundsOpInterface` for constraints. New items may be added to
// the worklist.
- auto valueBoundsOp =
- dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+ auto valueBoundsOp = dyn_cast<ValueBoundsOpInterface>(value.getOwningOp());
LLVM_DEBUG(llvm::dbgs()
<< "Query value bounds for: " << value
- << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
+ << " (owner: " << value.getOwningOp()->getName() << ")\n");
if (valueBoundsOp) {
if (dim == kIndexValue) {
valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -892,7 +885,7 @@ void ValueBoundsConstraintSet::dump() const {
} else {
llvm::errs() << valueDim->second << "\t";
}
- llvm::errs() << getOwnerOfValue(valueDim->first)->getName() << " ";
+ llvm::errs() << valueDim->first.getOwningOp()->getName() << " ";
if (OpResult result = dyn_cast<OpResult>(valueDim->first)) {
llvm::errs() << "(result " << result.getResultNumber() << ")";
} else {
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index a452cc3..8be1b1f 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -228,9 +228,7 @@ MemorySlotPromoter::MemorySlotPromoter(
blockIndexCache(blockIndexCache) {
#ifndef NDEBUG
auto isResultOrNewBlockArgument = [&]() {
- if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
- return arg.getOwner()->getParentOp() == allocator;
- return slot.ptr.getDefiningOp() == allocator;
+ return slot.ptr.getOwningOp() == allocator;
};
assert(isResultOrNewBlockArgument() &&