diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Bindings/Python/Rewrite.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 52 | ||||
-rw-r--r-- | mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 116 | ||||
-rw-r--r-- | mlir/lib/Dialect/Math/Transforms/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp | 80 | ||||
-rw-r--r-- | mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 23 | ||||
-rw-r--r-- | mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 6 | ||||
-rw-r--r-- | mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp | 184 | ||||
-rw-r--r-- | mlir/lib/IR/Builders.cpp | 7 | ||||
-rw-r--r-- | mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp | 2 |
11 files changed, 433 insertions, 42 deletions
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index f18298e..836f44fd 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -127,7 +127,7 @@ public: mlirPythonFrozenRewritePatternSetToCapsule(get())); } - static nb::object createFromCapsule(nb::object capsule) { + static nb::object createFromCapsule(const nb::object &capsule) { MlirFrozenRewritePatternSet rawPm = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); if (rawPm.ptr == nullptr) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index cc6314c..a6f816a 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -465,6 +465,51 @@ struct AssumeAlignmentOpLowering } }; +struct DistinctObjectsOpLowering + : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> { + using ConvertOpToLLVMPattern< + memref::DistinctObjectsOp>::ConvertOpToLLVMPattern; + explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {} + + LogicalResult + matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange operands = adaptor.getOperands(); + if (operands.size() <= 1) { + // Fast path. + rewriter.replaceOp(op, operands); + return success(); + } + + Location loc = op.getLoc(); + SmallVector<Value> ptrs; + for (auto [origOperand, newOperand] : + llvm::zip_equal(op.getOperands(), operands)) { + auto memrefType = cast<MemRefType>(origOperand.getType()); + MemRefDescriptor memRefDescriptor(newOperand); + Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), + memrefType); + ptrs.push_back(ptr); + } + + auto cond = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1); + // Generate separate_storage assumptions for each pair of pointers. + for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) { + for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) { + Value ptr1 = ptrs[i]; + Value ptr2 = ptrs[j]; + LLVM::AssumeOp::create(rewriter, loc, cond, + LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2); + } + } + + rewriter.replaceOp(op, operands); + return success(); + } +}; + // A `dealloc` is converted into a call to `free` on the underlying data buffer. // The memref descriptor being an SSA value, there is no need to clean it up // in any way. @@ -1997,22 +2042,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns( patterns.add< AllocaOpLowering, AllocaScopeOpLowering, - AtomicRMWOpLowering, AssumeAlignmentOpLowering, + AtomicRMWOpLowering, ConvertExtractAlignedPointerAsIndex, DimOpLowering, + DistinctObjectsOpLowering, ExtractStridedMetadataOpLowering, GenericAtomicRMWOpLowering, GetGlobalMemrefOpLowering, LoadOpLowering, MemRefCastOpLowering, - MemorySpaceCastOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, + MemorySpaceCastOpLowering, PrefetchOpLowering, RankOpLowering, - ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, ReassociatingReshapeOpConversion<memref::CollapseShapeOp>, + ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, StoreOpLowering, SubViewOpLowering, TransposeOpLowering, diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp index f3e065a..9821a75 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp @@ -246,6 +246,6 @@ void SimplifyAffineMinMaxPass::runOnOperation() { patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>( func.getContext()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - if (failed(applyPatternsGreedily(func, std::move(frozenPatterns)))) + if (failed(applyPatternsGreedily(func, frozenPatterns))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 3f0b0ba..dd9b4c2 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -42,6 +42,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/LogicalResult.h" @@ -273,32 +274,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns( // BufferizeToAllocationOp //===----------------------------------------------------------------------===// -void transform::BufferizeToAllocationOp::build(OpBuilder &b, - OperationState &result, - Value target, - Attribute memorySpace) { - SmallVector<Type> resultTypes; - resultTypes.push_back(b.getType<transform::AnyValueType>()); - resultTypes.push_back(b.getType<transform::AnyOpType>()); - return build(b, result, - /*resultTypes=*/resultTypes, - /*target=*/target, - /*memory_space=*/memorySpace); -} - -void transform::BufferizeToAllocationOp::build(OpBuilder &b, - OperationState &result, - Value target, - int64_t memorySpace) { - SmallVector<Type> resultTypes; - resultTypes.push_back(b.getType<transform::AnyValueType>()); - resultTypes.push_back(b.getType<transform::AnyOpType>()); - return build(b, result, - /*resultTypes=*/resultTypes, - /*target=*/target, - /*memory_space=*/b.getI64IntegerAttr(memorySpace)); -} - namespace { class NewOpsListener : public RewriterBase::ForwardingListener { public: @@ -409,6 +384,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() { } //===----------------------------------------------------------------------===// +// PromoteTensorOp +//===----------------------------------------------------------------------===// + +/// Return true if the operand may be read from by its owner. This is currently +/// very conservative and only looks inside linalg operations to prevent +/// unintentional data loss. +static bool mayBeRead(OpOperand &operand) { + auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner()); + + // Be conservative about ops we cannot analyze deeper. + if (!linalgOp) + return true; + + // Look inside linalg ops. + Value blockArgument = linalgOp.getMatchingBlockArgument(&operand); + return !blockArgument.use_empty(); +} + +/// Return true if the value may be read through any of its uses. +static bool mayBeRead(Value value) { + // If the value has a reference semantics, it + // may be read through any alias... + if (!isa<TensorType, FloatType, IntegerType>(value.getType())) + return true; + return llvm::any_of(value.getUses(), + static_cast<bool (&)(OpOperand &)>(mayBeRead)); +} + +DiagnosedSilenceableFailure +transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + SmallVector<Value> promoted; + for (Value tensor : state.getPayloadValues(getTensor())) { + auto type = dyn_cast<RankedTensorType>(tensor.getType()); + if (!type) { + return emitSilenceableError() << "non-tensor type: " << tensor; + } + + Operation *definingOp = tensor.getDefiningOp(); + if (definingOp) + rewriter.setInsertionPointAfter(definingOp); + else + rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner()); + + // Check this before we emit operations using this value. + bool needsMaterialization = mayBeRead(tensor); + + SmallVector<Value> dynamicDims; + llvm::SmallPtrSet<Operation *, 4> preservedOps; + for (auto [pos, dim] : llvm::enumerate(type.getShape())) { + if (!ShapedType::isDynamic(dim)) + continue; + Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos); + auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst); + preservedOps.insert(dimOp); + dynamicDims.push_back(dimOp); + } + auto allocation = rewriter.create<bufferization::AllocTensorOp>( + tensor.getLoc(), type, dynamicDims); + // Set memory space if provided. + if (getMemorySpaceAttr()) + allocation.setMemorySpaceAttr(getMemorySpaceAttr()); + Value allocated = allocation; + + // Only insert a materialization (typically bufferizes to a copy) when the + // value may be read from. + if (needsMaterialization) { + auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>( + tensor.getLoc(), tensor, allocated); + preservedOps.insert(copy); + promoted.push_back(copy.getResult()); + } else { + promoted.push_back(allocated); + } + rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps); + } + results.setValues(cast<OpResult>(getPromoted()), promoted); + return DiagnosedSilenceableFailure::success(); +} + +void transform::PromoteTensorOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::onlyReadsHandle(getTensorMutable(), effects); + transform::producesHandle(getOperation()->getOpResults(), effects); + transform::modifiesPayload(effects); +} + +//===----------------------------------------------------------------------===// // DecomposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index ff62b51..8899c3a 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRMathTransforms ExpandOps.cpp ExtendToSupportedTypes.cpp PolynomialApproximation.cpp + SincosFusion.cpp UpliftToFMA.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp new file mode 100644 index 0000000..69407df --- /dev/null +++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp @@ -0,0 +1,80 @@ +//===- SincosFusion.cpp - Fuse sin/cos into sincos -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::math; + +namespace { + +/// Fuse a math.sin and math.cos in the same block that use the same operand and +/// have identical fastmath flags into a single math.sincos. +struct SincosFusionPattern : OpRewritePattern<math::SinOp> { + using Base::Base; + + LogicalResult matchAndRewrite(math::SinOp sinOp, + PatternRewriter &rewriter) const override { + Value operand = sinOp.getOperand(); + mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath(); + + math::CosOp cosOp = nullptr; + sinOp->getBlock()->walk([&](math::CosOp op) { + if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) { + cosOp = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (!cosOp) + return failure(); + + Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? sinOp.getOperation() + : cosOp.getOperation(); + rewriter.setInsertionPoint(firstOp); + + Type elemType = sinOp.getType(); + auto sincos = math::SincosOp::create(rewriter, firstOp->getLoc(), + TypeRange{elemType, elemType}, operand, + sinOp.getFastmathAttr()); + + rewriter.replaceOp(sinOp, sincos.getSin()); + rewriter.replaceOp(cosOp, sincos.getCos()); + return success(); + } +}; + +} // namespace + +namespace mlir::math { +#define GEN_PASS_DEF_MATHSINCOSFUSIONPASS +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace mlir::math + +namespace { + +struct MathSincosFusionPass final + : math::impl::MathSincosFusionPassBase<MathSincosFusionPass> { + using MathSincosFusionPassBase::MathSincosFusionPassBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add<SincosFusionPattern>(&getContext()); + + GreedyRewriteConfig config; + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 349b4de..e9bdcda 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -607,6 +607,29 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) { } //===----------------------------------------------------------------------===// +// DistinctObjectsOp +//===----------------------------------------------------------------------===// + +LogicalResult DistinctObjectsOp::verify() { + if (getOperandTypes() != getResultTypes()) + return emitOpError("operand types and result types must match"); + + if (getOperandTypes().empty()) + return emitOpError("expected at least one operand"); + + return success(); +} + +LogicalResult DistinctObjectsOp::inferReturnTypes( + MLIRContext * /*context*/, std::optional<Location> /*location*/, + ValueRange operands, DictionaryAttr /*attributes*/, + OpaqueProperties /*properties*/, RegionRange /*regions*/, + SmallVectorImpl<Type> &inferredReturnTypes) { + llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes)); + return success(); +} + +//===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index f01ad05..a173cf1 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -182,7 +182,7 @@ static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) { } template <typename ClauseAttr> -void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) { +static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) { p << stringifyEnum(attr.getValue()); } @@ -1511,8 +1511,8 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { //===----------------------------------------------------------------------===// // Helper function to get bitwise AND of `value` and 'flag' -uint64_t mapTypeToBitFlag(uint64_t value, - llvm::omp::OpenMPOffloadMappingFlags flag) { +static uint64_t mapTypeToBitFlag(uint64_t value, + llvm::omp::OpenMPOffloadMappingFlags flag) { return value & llvm::to_underlying(flag); } diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index 842e880..c627158 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -6,13 +6,24 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" using namespace mlir; +static ParseResult parseAlternativesOpSelectedRegion( + OpAsmParser &parser, IntegerAttr &selectedRegionAttr, + std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam); + +static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, + Operation *op, + IntegerAttr selectedRegionAttr, + Value selectedRegionParam); + #define GET_OP_CLASSES #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc" @@ -57,3 +68,176 @@ LogicalResult transform::tune::KnobOp::verify() { return success(); } + +//===----------------------------------------------------------------------===// +// AlternativesOp +//===----------------------------------------------------------------------===// + +static ParseResult parseAlternativesOpSelectedRegion( + OpAsmParser &parser, IntegerAttr &selectedRegionAttr, + std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) { + size_t selectedRegionIdx; + OptionalParseResult attrParseRes = + parser.parseOptionalInteger(selectedRegionIdx); + if (attrParseRes.has_value()) { + if (failed(*attrParseRes)) + return failure(); + + selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx); + return success(); + } + + OpAsmParser::UnresolvedOperand param; + auto paramParseRes = parser.parseOptionalOperand(param); + if (paramParseRes.has_value()) { + if (failed(*paramParseRes)) + return failure(); + + selectedRegionParam = param; + return success(); + } + + return parser.emitError(parser.getCurrentLocation()) + << "expected either an integer attribute or a transform.param operand"; +} + +static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, + Operation *op, + IntegerAttr selectedRegionAttr, + Value selectedRegionParam) { + if (selectedRegionAttr) + printer << selectedRegionAttr.getValue(); + if (selectedRegionParam) + printer << selectedRegionParam; +} + +OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands( + RegionBranchPoint point) { + // No operands will be forwarded to the region(s). + return getOperands().slice(0, 0); +} + +void transform::tune::AlternativesOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { + if (point.isParent()) + if (auto selectedRegionIdx = getSelectedRegionAttr()) + regions.emplace_back( + &getAlternatives()[selectedRegionIdx->getSExtValue()], + Block::BlockArgListType()); + else + for (Region &alternative : getAlternatives()) + regions.emplace_back(&alternative, Block::BlockArgListType()); + else + regions.emplace_back(getOperation()->getResults()); +} + +void transform::tune::AlternativesOp::getRegionInvocationBounds( + ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { + (void)operands; + bounds.reserve(getNumRegions()); + + if (auto selectedRegionIdx = getSelectedRegionAttr()) { + bounds.resize(getNumRegions(), InvocationBounds(0, 0)); + bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1); + } else { + bounds.resize(getNumRegions(), InvocationBounds(0, 1)); + } +} + +void transform::tune::AlternativesOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getSelectedRegionParamMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + // TODO: should effects from regions be forwarded? +} + +DiagnosedSilenceableFailure +transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + std::optional<size_t> selectedRegionIdx; + + if (auto selectedRegionAttr = getSelectedRegionAttr()) + selectedRegionIdx = selectedRegionAttr->getSExtValue(); + + if (Value selectedRegionParam = getSelectedRegionParam()) { + ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam); + IntegerAttr selectedRegionAttr; + if (associatedAttrs.size() != 1 || + !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0]))) + return emitDefiniteFailure() + << "param should hold exactly one integer attribute, got: " + << associatedAttrs[0]; + selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue(); + } + + if (!selectedRegionIdx) + return emitDefiniteFailure() << "non-deterministic choice " << getName() + << " is only resolved through providing a " + "`selected_region` attr/param"; + + if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions()) + return emitDefiniteFailure() + << "'selected_region' attribute/param specifies region at index " + << *selectedRegionIdx << " while op has only " << getNumRegions() + << " regions"; + + Region &selectedRegion = getRegion(*selectedRegionIdx); + auto scope = state.make_region_scope(selectedRegion); + Block &block = selectedRegion.front(); + // Apply the region's ops one by one. + for (Operation &transform : block.without_terminator()) { + DiagnosedSilenceableFailure result = + state.applyTransform(cast<transform::TransformOpInterface>(transform)); + if (result.isDefiniteFailure()) + return result; + + if (result.isSilenceableFailure()) { + for (const auto &res : getResults()) + results.set(res, {}); + return result; + } + } + // Forward the operation mapping for values yielded from the region to the + // values produced by the alternatives op. + transform::detail::forwardTerminatorOperands(&block, state, results); + return DiagnosedSilenceableFailure::success(); +} + +LogicalResult transform::tune::AlternativesOp::verify() { + for (auto *region : getRegions()) { + auto yieldTerminator = + llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back()); + if (!yieldTerminator) + return emitOpError() << "expected '" + << transform::YieldOp::getOperationName() + << "' as terminator"; + + if (yieldTerminator->getNumOperands() != getNumResults()) + return yieldTerminator.emitOpError() + << "expected terminator to have as many operands as the parent op " + "has results"; + + for (auto [i, operandType, resultType] : llvm::zip_equal( + llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()), + yieldTerminator->getOperands().getType(), getResultTypes())) { + if (operandType == resultType) + continue; + return yieldTerminator.emitOpError() + << "the type of the terminator operand #" << i + << " must match the type of the corresponding parent op result (" + << operandType << " vs " << resultType << ")"; + } + } + + if (auto selectedRegionAttr = getSelectedRegionAttr()) { + size_t regionIdx = selectedRegionAttr->getSExtValue(); + if (regionIdx < 0 || regionIdx >= getNumRegions()) + return emitOpError() + << "'selected_region' attribute specifies region at index " + << regionIdx << " while op has only " << getNumRegions() + << " regions"; + } + + return success(); +} diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index c84e760..8f199b6 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -489,13 +489,6 @@ OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results, SmallVector<OpFoldResult, 4> foldResults; LDBG() << "Trying to fold: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); - if (op->getName().getStringRef() == "vector.extract") { - Operation *parent = op->getParentOp(); - while (parent && parent->getName().getStringRef() != "spirv.func") - parent = parent->getParentOp(); - if (parent) - parent->dump(); - } if (failed(op->fold(foldResults))) return cleanupFailure(); diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp index af4ea5a..0f28cbc 100644 --- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -304,7 +304,7 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, umin = lhsMin.udiv(rhsMax); // X u/ Y u<= X. - APInt umax = lhsMax; + const APInt &umax = lhsMax; return ConstantIntRanges::fromUnsigned(umin, umax); } |