diff options
Diffstat (limited to 'mlir/lib/Interfaces/MemOpInterfaces.cpp')
-rw-r--r-- | mlir/lib/Interfaces/MemOpInterfaces.cpp | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp new file mode 100644 index 0000000..fe5c717 --- /dev/null +++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp @@ -0,0 +1,73 @@ +//===- MemOpInterfaces.cpp - Memory operation interfaces ---------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/MemOpInterfaces.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" + +using namespace mlir; + +LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) { + auto memCastOp = cast<MemorySpaceCastOpInterface>(op); + + // Verify that the source and target pointers are valid + Value sourcePtr = memCastOp.getSourcePtr(); + Value targetPtr = memCastOp.getTargetPtr(); + + if (!sourcePtr || !targetPtr) { + return op->emitError() + << "memory space cast op must have valid source and target pointers"; + } + + if (sourcePtr.getType().getTypeID() != targetPtr.getType().getTypeID()) { + return op->emitError() + << "expected source and target types of the same kind"; + } + + // Verify the Types are of `PtrLikeTypeInterface` type. + auto sourceType = dyn_cast<PtrLikeTypeInterface>(sourcePtr.getType()); + if (!sourceType) { + return op->emitError() + << "source type must implement `PtrLikeTypeInterface`, but got: " + << sourcePtr.getType(); + } + + auto targetType = dyn_cast<PtrLikeTypeInterface>(targetPtr.getType()); + if (!targetType) { + return op->emitError() + << "target type must implement `PtrLikeTypeInterface`, but got: " + << targetPtr.getType(); + } + + // Verify that the operation has exactly one result + if (op->getNumResults() != 1) { + return op->emitError() + << "memory space cast op must have exactly one result"; + } + + return success(); +} + +FailureOr<std::optional<SmallVector<Value>>> +mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, + ValueRange results) { + MemorySpaceCastOpInterface castOp = + MemorySpaceCastOpInterface::getIfPromotableCast(operand.get()); + + // Bail if the src is not valid. + if (!castOp) + return failure(); + + // Modify the op. + operand.set(castOp.getSourcePtr()); + return std::optional<SmallVector<Value>>(); +} + +#include "mlir/Interfaces/MemOpInterfaces.cpp.inc" |