aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Bindings/Python/Rewrite.cpp2
-rw-r--r--mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp52
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp116
-rw-r--r--mlir/lib/Dialect/Math/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp80
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp23
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp6
-rw-r--r--mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp184
-rw-r--r--mlir/lib/IR/Builders.cpp7
-rw-r--r--mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp2
11 files changed, 433 insertions, 42 deletions
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index f18298e..836f44fd 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -127,7 +127,7 @@ public:
mlirPythonFrozenRewritePatternSetToCapsule(get()));
}
- static nb::object createFromCapsule(nb::object capsule) {
+ static nb::object createFromCapsule(const nb::object &capsule) {
MlirFrozenRewritePatternSet rawPm =
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
if (rawPm.ptr == nullptr)
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index cc6314c..a6f816a 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -465,6 +465,51 @@ struct AssumeAlignmentOpLowering
}
};
+struct DistinctObjectsOpLowering
+ : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
+ using ConvertOpToLLVMPattern<
+ memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
+ explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
+ : ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {}
+
+ LogicalResult
+ matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ValueRange operands = adaptor.getOperands();
+ if (operands.size() <= 1) {
+ // Fast path.
+ rewriter.replaceOp(op, operands);
+ return success();
+ }
+
+ Location loc = op.getLoc();
+ SmallVector<Value> ptrs;
+ for (auto [origOperand, newOperand] :
+ llvm::zip_equal(op.getOperands(), operands)) {
+ auto memrefType = cast<MemRefType>(origOperand.getType());
+ MemRefDescriptor memRefDescriptor(newOperand);
+ Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+ memrefType);
+ ptrs.push_back(ptr);
+ }
+
+ auto cond =
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
+ // Generate separate_storage assumptions for each pair of pointers.
+ for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
+ for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
+ Value ptr1 = ptrs[i];
+ Value ptr2 = ptrs[j];
+ LLVM::AssumeOp::create(rewriter, loc, cond,
+ LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
+ }
+ }
+
+ rewriter.replaceOp(op, operands);
+ return success();
+ }
+};
+
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
@@ -1997,22 +2042,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
patterns.add<
AllocaOpLowering,
AllocaScopeOpLowering,
- AtomicRMWOpLowering,
AssumeAlignmentOpLowering,
+ AtomicRMWOpLowering,
ConvertExtractAlignedPointerAsIndex,
DimOpLowering,
+ DistinctObjectsOpLowering,
ExtractStridedMetadataOpLowering,
GenericAtomicRMWOpLowering,
GetGlobalMemrefOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
- MemorySpaceCastOpLowering,
MemRefReinterpretCastOpLowering,
MemRefReshapeOpLowering,
+ MemorySpaceCastOpLowering,
PrefetchOpLowering,
RankOpLowering,
- ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
+ ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
StoreOpLowering,
SubViewOpLowering,
TransposeOpLowering,
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
index f3e065a..9821a75 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -246,6 +246,6 @@ void SimplifyAffineMinMaxPass::runOnOperation() {
patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>(
func.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- if (failed(applyPatternsGreedily(func, std::move(frozenPatterns))))
+ if (failed(applyPatternsGreedily(func, frozenPatterns)))
return signalPassFailure();
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3f0b0ba..dd9b4c2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -42,6 +42,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
@@ -273,32 +274,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
- OperationState &result,
- Value target,
- Attribute memorySpace) {
- SmallVector<Type> resultTypes;
- resultTypes.push_back(b.getType<transform::AnyValueType>());
- resultTypes.push_back(b.getType<transform::AnyOpType>());
- return build(b, result,
- /*resultTypes=*/resultTypes,
- /*target=*/target,
- /*memory_space=*/memorySpace);
-}
-
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
- OperationState &result,
- Value target,
- int64_t memorySpace) {
- SmallVector<Type> resultTypes;
- resultTypes.push_back(b.getType<transform::AnyValueType>());
- resultTypes.push_back(b.getType<transform::AnyOpType>());
- return build(b, result,
- /*resultTypes=*/resultTypes,
- /*target=*/target,
- /*memory_space=*/b.getI64IntegerAttr(memorySpace));
-}
-
namespace {
class NewOpsListener : public RewriterBase::ForwardingListener {
public:
@@ -409,6 +384,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
}
//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the operand may be read from by its owner. This is currently
+/// very conservative and only looks inside linalg operations to prevent
+/// unintentional data loss.
+static bool mayBeRead(OpOperand &operand) {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+
+ // Be conservative about ops we cannot analyze deeper.
+ if (!linalgOp)
+ return true;
+
+ // Look inside linalg ops.
+ Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
+ return !blockArgument.use_empty();
+}
+
+/// Return true if the value may be read through any of its uses.
+static bool mayBeRead(Value value) {
+ // If the value has a reference semantics, it
+ // may be read through any alias...
+ if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
+ return true;
+ return llvm::any_of(value.getUses(),
+ static_cast<bool (&)(OpOperand &)>(mayBeRead));
+}
+
+DiagnosedSilenceableFailure
+transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Value> promoted;
+ for (Value tensor : state.getPayloadValues(getTensor())) {
+ auto type = dyn_cast<RankedTensorType>(tensor.getType());
+ if (!type) {
+ return emitSilenceableError() << "non-tensor type: " << tensor;
+ }
+
+ Operation *definingOp = tensor.getDefiningOp();
+ if (definingOp)
+ rewriter.setInsertionPointAfter(definingOp);
+ else
+ rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
+
+ // Check this before we emit operations using this value.
+ bool needsMaterialization = mayBeRead(tensor);
+
+ SmallVector<Value> dynamicDims;
+ llvm::SmallPtrSet<Operation *, 4> preservedOps;
+ for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
+ if (!ShapedType::isDynamic(dim))
+ continue;
+ Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
+ auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+ preservedOps.insert(dimOp);
+ dynamicDims.push_back(dimOp);
+ }
+ auto allocation = rewriter.create<bufferization::AllocTensorOp>(
+ tensor.getLoc(), type, dynamicDims);
+ // Set memory space if provided.
+ if (getMemorySpaceAttr())
+ allocation.setMemorySpaceAttr(getMemorySpaceAttr());
+ Value allocated = allocation;
+
+ // Only insert a materialization (typically bufferizes to a copy) when the
+ // value may be read from.
+ if (needsMaterialization) {
+ auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+ tensor.getLoc(), tensor, allocated);
+ preservedOps.insert(copy);
+ promoted.push_back(copy.getResult());
+ } else {
+ promoted.push_back(allocated);
+ }
+ rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
+ }
+ results.setValues(cast<OpResult>(getPromoted()), promoted);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PromoteTensorOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getTensorMutable(), effects);
+ transform::producesHandle(getOperation()->getOpResults(), effects);
+ transform::modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index ff62b51..8899c3a 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRMathTransforms
ExpandOps.cpp
ExtendToSupportedTypes.cpp
PolynomialApproximation.cpp
+ SincosFusion.cpp
UpliftToFMA.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
new file mode 100644
index 0000000..69407df
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -0,0 +1,80 @@
+//===- SincosFusion.cpp - Fuse sin/cos into sincos -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::math;
+
+namespace {
+
+/// Fuse a math.sin and math.cos in the same block that use the same operand and
+/// have identical fastmath flags into a single math.sincos.
+struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(math::SinOp sinOp,
+ PatternRewriter &rewriter) const override {
+ Value operand = sinOp.getOperand();
+ mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
+
+ math::CosOp cosOp = nullptr;
+ sinOp->getBlock()->walk([&](math::CosOp op) {
+ if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
+ cosOp = op;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+
+ if (!cosOp)
+ return failure();
+
+ Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? sinOp.getOperation()
+ : cosOp.getOperation();
+ rewriter.setInsertionPoint(firstOp);
+
+ Type elemType = sinOp.getType();
+ auto sincos = math::SincosOp::create(rewriter, firstOp->getLoc(),
+ TypeRange{elemType, elemType}, operand,
+ sinOp.getFastmathAttr());
+
+ rewriter.replaceOp(sinOp, sincos.getSin());
+ rewriter.replaceOp(cosOp, sincos.getCos());
+ return success();
+ }
+};
+
+} // namespace
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHSINCOSFUSIONPASS
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+namespace {
+
+struct MathSincosFusionPass final
+ : math::impl::MathSincosFusionPassBase<MathSincosFusionPass> {
+ using MathSincosFusionPassBase::MathSincosFusionPassBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ patterns.add<SincosFusionPattern>(&getContext());
+
+ GreedyRewriteConfig config;
+ if (failed(
+ applyPatternsGreedily(getOperation(), std::move(patterns), config)))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 349b4de..e9bdcda 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -607,6 +607,29 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
}
//===----------------------------------------------------------------------===//
+// DistinctObjectsOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DistinctObjectsOp::verify() {
+ if (getOperandTypes() != getResultTypes())
+ return emitOpError("operand types and result types must match");
+
+ if (getOperandTypes().empty())
+ return emitOpError("expected at least one operand");
+
+ return success();
+}
+
+LogicalResult DistinctObjectsOp::inferReturnTypes(
+ MLIRContext * /*context*/, std::optional<Location> /*location*/,
+ ValueRange operands, DictionaryAttr /*attributes*/,
+ OpaqueProperties /*properties*/, RegionRange /*regions*/,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index f01ad05..a173cf1 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -182,7 +182,7 @@ static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
}
template <typename ClauseAttr>
-void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
+static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
p << stringifyEnum(attr.getValue());
}
@@ -1511,8 +1511,8 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
//===----------------------------------------------------------------------===//
// Helper function to get bitwise AND of `value` and 'flag'
-uint64_t mapTypeToBitFlag(uint64_t value,
- llvm::omp::OpenMPOffloadMappingFlags flag) {
+static uint64_t mapTypeToBitFlag(uint64_t value,
+ llvm::omp::OpenMPOffloadMappingFlags flag) {
return value & llvm::to_underlying(flag);
}
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index 842e880..c627158 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -6,13 +6,24 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
using namespace mlir;
+static ParseResult parseAlternativesOpSelectedRegion(
+ OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
+ std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
+
+static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
+ Operation *op,
+ IntegerAttr selectedRegionAttr,
+ Value selectedRegionParam);
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
@@ -57,3 +68,176 @@ LogicalResult transform::tune::KnobOp::verify() {
return success();
}
+
+//===----------------------------------------------------------------------===//
+// AlternativesOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAlternativesOpSelectedRegion(
+ OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
+ std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
+ size_t selectedRegionIdx;
+ OptionalParseResult attrParseRes =
+ parser.parseOptionalInteger(selectedRegionIdx);
+ if (attrParseRes.has_value()) {
+ if (failed(*attrParseRes))
+ return failure();
+
+ selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx);
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand param;
+ auto paramParseRes = parser.parseOptionalOperand(param);
+ if (paramParseRes.has_value()) {
+ if (failed(*paramParseRes))
+ return failure();
+
+ selectedRegionParam = param;
+ return success();
+ }
+
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected either an integer attribute or a transform.param operand";
+}
+
+static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
+ Operation *op,
+ IntegerAttr selectedRegionAttr,
+ Value selectedRegionParam) {
+ if (selectedRegionAttr)
+ printer << selectedRegionAttr.getValue();
+ if (selectedRegionParam)
+ printer << selectedRegionParam;
+}
+
+OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
+ RegionBranchPoint point) {
+ // No operands will be forwarded to the region(s).
+ return getOperands().slice(0, 0);
+}
+
+void transform::tune::AlternativesOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent())
+ if (auto selectedRegionIdx = getSelectedRegionAttr())
+ regions.emplace_back(
+ &getAlternatives()[selectedRegionIdx->getSExtValue()],
+ Block::BlockArgListType());
+ else
+ for (Region &alternative : getAlternatives())
+ regions.emplace_back(&alternative, Block::BlockArgListType());
+ else
+ regions.emplace_back(getOperation()->getResults());
+}
+
+void transform::tune::AlternativesOp::getRegionInvocationBounds(
+ ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+ (void)operands;
+ bounds.reserve(getNumRegions());
+
+ if (auto selectedRegionIdx = getSelectedRegionAttr()) {
+ bounds.resize(getNumRegions(), InvocationBounds(0, 0));
+ bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1);
+ } else {
+ bounds.resize(getNumRegions(), InvocationBounds(0, 1));
+ }
+}
+
+void transform::tune::AlternativesOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getSelectedRegionParamMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ // TODO: should effects from regions be forwarded?
+}
+
+DiagnosedSilenceableFailure
+transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ std::optional<size_t> selectedRegionIdx;
+
+ if (auto selectedRegionAttr = getSelectedRegionAttr())
+ selectedRegionIdx = selectedRegionAttr->getSExtValue();
+
+ if (Value selectedRegionParam = getSelectedRegionParam()) {
+ ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
+ IntegerAttr selectedRegionAttr;
+ if (associatedAttrs.size() != 1 ||
+ !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
+ return emitDefiniteFailure()
+ << "param should hold exactly one integer attribute, got: "
+ << associatedAttrs[0];
+ selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
+ }
+
+ if (!selectedRegionIdx)
+ return emitDefiniteFailure() << "non-deterministic choice " << getName()
+ << " is only resolved through providing a "
+ "`selected_region` attr/param";
+
+ if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
+ return emitDefiniteFailure()
+ << "'selected_region' attribute/param specifies region at index "
+ << *selectedRegionIdx << " while op has only " << getNumRegions()
+ << " regions";
+
+ Region &selectedRegion = getRegion(*selectedRegionIdx);
+ auto scope = state.make_region_scope(selectedRegion);
+ Block &block = selectedRegion.front();
+ // Apply the region's ops one by one.
+ for (Operation &transform : block.without_terminator()) {
+ DiagnosedSilenceableFailure result =
+ state.applyTransform(cast<transform::TransformOpInterface>(transform));
+ if (result.isDefiniteFailure())
+ return result;
+
+ if (result.isSilenceableFailure()) {
+ for (const auto &res : getResults())
+ results.set(res, {});
+ return result;
+ }
+ }
+ // Forward the operation mapping for values yielded from the region to the
+ // values produced by the alternatives op.
+ transform::detail::forwardTerminatorOperands(&block, state, results);
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::tune::AlternativesOp::verify() {
+ for (auto *region : getRegions()) {
+ auto yieldTerminator =
+ llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
+ if (!yieldTerminator)
+ return emitOpError() << "expected '"
+ << transform::YieldOp::getOperationName()
+ << "' as terminator";
+
+ if (yieldTerminator->getNumOperands() != getNumResults())
+ return yieldTerminator.emitOpError()
+ << "expected terminator to have as many operands as the parent op "
+ "has results";
+
+ for (auto [i, operandType, resultType] : llvm::zip_equal(
+ llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+ yieldTerminator->getOperands().getType(), getResultTypes())) {
+ if (operandType == resultType)
+ continue;
+ return yieldTerminator.emitOpError()
+ << "the type of the terminator operand #" << i
+ << " must match the type of the corresponding parent op result ("
+ << operandType << " vs " << resultType << ")";
+ }
+ }
+
+ if (auto selectedRegionAttr = getSelectedRegionAttr()) {
+ size_t regionIdx = selectedRegionAttr->getSExtValue();
+ if (regionIdx < 0 || regionIdx >= getNumRegions())
+ return emitOpError()
+ << "'selected_region' attribute specifies region at index "
+ << regionIdx << " while op has only " << getNumRegions()
+ << " regions";
+ }
+
+ return success();
+}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index c84e760..8f199b6 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -489,13 +489,6 @@ OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
SmallVector<OpFoldResult, 4> foldResults;
LDBG() << "Trying to fold: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
- if (op->getName().getStringRef() == "vector.extract") {
- Operation *parent = op->getParentOp();
- while (parent && parent->getName().getStringRef() != "spirv.func")
- parent = parent->getParentOp();
- if (parent)
- parent->dump();
- }
if (failed(op->fold(foldResults)))
return cleanupFailure();
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index af4ea5a..0f28cbc 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -304,7 +304,7 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
umin = lhsMin.udiv(rhsMax);
// X u/ Y u<= X.
- APInt umax = lhsMax;
+ const APInt &umax = lhsMax;
return ConstantIntRanges::fromUnsigned(umin, umax);
}