diff options
Diffstat (limited to 'mlir/lib/Dialect')
20 files changed, 154 insertions, 97 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 585b6da..df955fc 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() { if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) { return emitOpError( - "source element types much match (except for fp8) but have ") + "source element types must match (except for fp8/bf8) but have ") << sourceAType << " and " << sourceBType; } - if (!sourceAElemType.isInteger(4) && getK() != 16) { - return emitOpError("K dimension must be 16 for source element type ") - << sourceAElemType; + if (isSrcFloat) { + if (getClamp()) + return emitOpError("clamp flag is not supported for float types"); + if (getUnsignedA() || getUnsignedB()) + return emitOpError("unsigned flags are not supported for float types"); } return success(); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index e0a53cd..0c35921 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2716,8 +2716,9 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor, return success(folded); } -OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert((point.isParent() || point == getRegion()) && "invalid region point"); +OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert((successor.isParent() || successor.getSuccessor() == &getRegion()) && + "invalid region point"); // The initial operands map to the loop arguments after the induction // variable or are forwarded to the results when the trip count is zero. @@ -2726,34 +2727,41 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { void AffineForOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { - assert((point.isParent() || point == getRegion()) && "expected loop region"); + assert((point.isParent() || + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getRegion()) && + "expected loop region"); // The loop may typically branch back to its body or to the parent operation. // If the predecessor is the parent op and the trip count is known to be at // least one, branch into the body using the iterator arguments. And in cases // we know the trip count is zero, it can only branch back to its parent. std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this); - if (point.isParent() && tripCount.has_value()) { - if (tripCount.value() > 0) { - regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - return; - } - if (tripCount.value() == 0) { - regions.push_back(RegionSuccessor(getResults())); - return; + if (tripCount.has_value()) { + if (!point.isParent()) { + // From the loop body, if the trip count is one, we can only branch back + // to the parent. + if (tripCount == 1) { + regions.push_back(RegionSuccessor(getOperation(), getResults())); + return; + } + if (tripCount == 0) + return; + } else { + if (tripCount.value() > 0) { + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); + return; + } + if (tripCount.value() == 0) { + regions.push_back(RegionSuccessor(getOperation(), getResults())); + return; + } } } - // From the loop body, if the trip count is one, we can only branch back to - // the parent. - if (!point.isParent() && tripCount == 1) { - regions.push_back(RegionSuccessor(getResults())); - return; - } - // In all other cases, the loop may branch back to itself or the parent // operation. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } AffineBound AffineForOp::getLowerBound() { @@ -3142,7 +3150,7 @@ void AffineIfOp::getSuccessorRegions( RegionSuccessor(&getThenRegion(), getThenRegion().getArguments())); // If the "else" region is empty, branch bach into parent. if (getElseRegion().empty()) { - regions.push_back(getResults()); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } else { regions.push_back( RegionSuccessor(&getElseRegion(), getElseRegion().getArguments())); @@ -3152,7 +3160,7 @@ void AffineIfOp::getSuccessorRegions( // If the predecessor is the `else`/`then` region, then branching into parent // op is valid. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } LogicalResult AffineIfOp::verify() { diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index d925c19..a651710 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -216,8 +216,8 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) { for (auto condBranch : worklist) { auto loc = condBranch.getLoc(); Block *block = condBranch->getBlock(); - auto newTrueBranch = rewriter.splitBlock(block, block->end()); - auto newFalseBranch = rewriter.splitBlock(block, block->end()); + auto *newTrueBranch = rewriter.splitBlock(block, block->end()); + auto *newFalseBranch = rewriter.splitBlock(block, block->end()); insertJump(loc, newTrueBranch, condBranch.getTrueDest(), condBranch.getTrueDestOperands()); insertJump(loc, newFalseBranch, condBranch.getFalseDest(), @@ -382,7 +382,7 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap, // Find or create a live range for `value`. auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator); LiveRange &valueLiveRange = it->second; - auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef); + auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef); // Add the interval [firstUseOrDef, lastUseInBlock) to the live range. unsigned startOpIdx = operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index dc7b07d..8e4a49d 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -36,8 +36,9 @@ void AsyncDialect::initialize() { constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes"; -OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBodyRegion() && "invalid region index"); +OperandRange ExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBodyRegion() && + "invalid region index"); return getBodyOperands(); } @@ -53,8 +54,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) { void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { // The `body` region branch back to the parent operation. - if (point == getBodyRegion()) { - regions.push_back(RegionSuccessor(getBodyResults())); + if (!point.isParent() && + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBodyRegion()) { + regions.push_back(RegionSuccessor(getOperation(), getBodyResults())); return; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index b593cca..36a759c 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -562,8 +562,11 @@ LogicalResult BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) { SmallVector<TypeRange> returnOperandTypes(llvm::map_range( op.getFunctionBody().getOps<RegionBranchTerminatorOpInterface>(), - [](RegionBranchTerminatorOpInterface op) { - return op.getSuccessorOperands(RegionBranchPoint::parent()).getTypes(); + [&](RegionBranchTerminatorOpInterface branchOp) { + return branchOp + .getSuccessorOperands(RegionSuccessor( + op.getOperation(), op.getOperation()->getResults())) + .getTypes(); })); if (!llvm::all_equal(returnOperandTypes)) return op->emitError( @@ -942,8 +945,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) { // about, but we would need to check how many successors there are and under // which condition they are taken, etc. - MutableOperandRange operands = - op.getMutableSuccessorOperands(RegionBranchPoint::parent()); + MutableOperandRange operands = op.getMutableSuccessorOperands( + RegionSuccessor(op.getOperation(), op.getOperation()->getResults())); SmallVector<Value> updatedOwnerships; auto result = deallocation_impl::insertDeallocOpForReturnLike( diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 4754f0b..0992ce14 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -845,7 +845,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); return; } @@ -854,7 +855,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -871,7 +873,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index b5f8dda..6c6d8d2 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2399,7 +2399,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser, void WarpExecuteOnLane0Op::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index eb2d825..bd25e94 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -495,13 +495,14 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, if (failed(maybePackedDimForEachOperand)) return failure(); packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand; - listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims)); LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i]; LDBG() << "maps: " << llvm::interleaved(indexingMaps); LDBG() << "iterators: " << llvm::interleaved(iteratorTypes); LDBG() << "packedDimForEachOperand: " << llvm::interleaved(packedOperandsDims.packedDimForEachOperand); + + listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims)); } // Step 2. Propagate packing to all LinalgOp operands. diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index c551fba..1c21a2f 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { void AllocaScopeOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index 6fa8ce4..69afbca 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -98,6 +98,27 @@ struct RankOpInterface } }; +struct CollapseShapeOpInterface + : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface, + memref::CollapseShapeOp> { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto collapseOp = cast<memref::CollapseShapeOp>(op); + assert(value == collapseOp.getResult() && "invalid value"); + + // Multiply the expressions for the dimensions in the reassociation group. + const ReassociationIndices reassocIndices = + collapseOp.getReassociationIndices()[dim]; + AffineExpr productExpr = + cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]); + for (size_t i = 1; i < reassocIndices.size(); ++i) { + productExpr = + productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]); + } + cstr.bound(value)[dim] == productExpr; + } +}; + struct SubViewOpInterface : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface, SubViewOp> { @@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels( memref::AllocOpInterface<memref::AllocaOp>>(*ctx); memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); + memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>( + *ctx); memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>( *ctx); memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 1ab01d8..2946b53 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -397,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions( } // Otherwise, the region branches back to the parent operation. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } //===----------------------------------------------------------------------===// @@ -405,10 +405,11 @@ void ExecuteRegionOp::getSuccessorRegions( //===----------------------------------------------------------------------===// MutableOperandRange -ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { - assert((point.isParent() || point == getParentOp().getAfter()) && - "condition op can only exit the loop or branch to the after" - "region"); +ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) { + assert( + (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) && + "condition op can only exit the loop or branch to the after" + "region"); // Pass all operands except the condition to the successor region. return getArgsMutable(); } @@ -426,7 +427,7 @@ void ConditionOp::getSuccessorRegions( regions.emplace_back(&whileOp.getAfter(), whileOp.getAfter().getArguments()); if (!boolAttr || !boolAttr.getValue()) - regions.emplace_back(whileOp.getResults()); + regions.emplace_back(whileOp.getOperation(), whileOp.getResults()); } //===----------------------------------------------------------------------===// @@ -749,7 +750,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) { return dyn_cast_or_null<ForOp>(containingOp); } -OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) { return getInitArgs(); } @@ -759,7 +760,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, // back into the operation itself. It is possible for loop not to enter the // body. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } @@ -2053,9 +2054,10 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point, // parallel by multiple threads. We should not expect to branch back into // the forall body after the region's execution is complete. if (point.isParent()) - regions.push_back(RegionSuccessor(&getRegion())); + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); else - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); } //===----------------------------------------------------------------------===// @@ -2333,9 +2335,10 @@ void IfOp::print(OpAsmPrinter &p) { void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { - // The `then` and the `else` region branch back to the parent operation. + // The `then` and the `else` region branch back to the parent operation or one + // of the recursive parent operations (early exit case). if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -2344,7 +2347,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -2361,7 +2365,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); } } @@ -3385,7 +3389,8 @@ void ParallelOp::getSuccessorRegions( // back into the operation itself. It is possible for loop not to enter the // body. regions.push_back(RegionSuccessor(&getRegion())); - regions.push_back(RegionSuccessor()); + regions.push_back(RegionSuccessor( + getOperation(), ResultRange{getResults().end(), getResults().end()})); } //===----------------------------------------------------------------------===// @@ -3431,7 +3436,7 @@ LogicalResult ReduceOp::verifyRegions() { } MutableOperandRange -ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { +ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) { // No operands are forwarded to the next iteration. return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0); } @@ -3514,8 +3519,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() { return getBeforeArguments(); } -OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBefore() && +OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBefore() && "WhileOp is expected to branch only to the first region"); return getInits(); } @@ -3528,15 +3533,18 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point, return; } - assert(llvm::is_contained({&getAfter(), &getBefore()}, point) && + assert(llvm::is_contained( + {&getAfter(), &getBefore()}, + point.getTerminatorPredecessorOrNull()->getParentRegion()) && "there are only two regions in a WhileOp"); // The body region always branches back to the condition region. - if (point == getAfter()) { + if (point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getAfter()) { regions.emplace_back(&getBefore(), getBefore().getArguments()); return; } - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); regions.emplace_back(&getAfter(), getAfter().getArguments()); } @@ -4445,7 +4453,7 @@ void IndexSwitchOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) { // All regions branch back to the parent op. if (!point.isParent()) { - successors.emplace_back(getResults()); + successors.emplace_back(getOperation(), getResults()); return; } diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index ae52af5..ddcbda8 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -23,7 +23,6 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir -using namespace llvm; using namespace mlir; using scf::ForOp; using scf::WhileOp; diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp index a2f03f1..00bef70 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp @@ -21,7 +21,6 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir -using namespace llvm; using namespace mlir; using scf::LoopNest; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 5ba8289..f0f22e5 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -346,7 +346,7 @@ void AssumingOp::getSuccessorRegions( // parent, so return the correct RegionSuccessor purely based on the index // being None or 0. if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 1a9d9e1..3962e3e 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2597,7 +2597,7 @@ std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } -OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) { return getInitArgs(); } @@ -2607,7 +2607,7 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point, // or back into the operation itself. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); // It is possible for loop not to enter the body. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp index f53d272..ffa8b40 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp @@ -152,19 +152,20 @@ IterationGraphSorter IterationGraphSorter::fromGenericOp( } IterationGraphSorter::IterationGraphSorter( - SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out, - AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes, + SmallVector<Value> &&insArg, SmallVector<AffineMap> &&loop2InsLvlArg, + Value out, AffineMap loop2OutLvl, + SmallVector<utils::IteratorType> &&iterTypesArg, sparse_tensor::LoopOrderingStrategy strategy) - : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out), - loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)), + : ins(std::move(insArg)), loop2InsLvl(std::move(loop2InsLvlArg)), out(out), + loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypesArg)), strategy(strategy) { // One map per tensor. - assert(this->loop2InsLvl.size() == this->ins.size()); + assert(loop2InsLvl.size() == ins.size()); // All the affine maps have the same number of dimensions (loops). assert(llvm::all_equal(llvm::map_range( - this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); }))); + loop2InsLvl, [](AffineMap m) { return m.getNumDims(); }))); // The number of results of the map should match the rank of the tensor. - assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) { + assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) { auto [m, v] = mvPair; // For ranked types the rank must match. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h index b2a16e9..35e58ed 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h @@ -59,10 +59,10 @@ public: private: // Private constructor. - IterationGraphSorter(SmallVector<Value> &&ins, - SmallVector<AffineMap> &&loop2InsLvl, Value out, + IterationGraphSorter(SmallVector<Value> &&insArg, + SmallVector<AffineMap> &&loop2InsLvlArg, Value out, AffineMap loop2OutLvl, - SmallVector<utils::IteratorType> &&iterTypes, + SmallVector<utils::IteratorType> &&iterTypesArg, sparse_tensor::LoopOrderingStrategy strategy = sparse_tensor::LoopOrderingStrategy::kDefault); diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp index 1e3b377..549ac7a 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp @@ -77,7 +77,7 @@ FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer( dyn_cast<TilingInterface>(consumerOperands.front()->getOwner()); if (!consumerOp) return failure(); - for (auto opOperand : consumerOperands.drop_front()) { + for (auto *opOperand : consumerOperands.drop_front()) { if (opOperand->getOwner() != consumerOp) { LLVM_DEBUG({ llvm::dbgs() diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 365afab..062606e 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -96,9 +96,9 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, // AlternativesOp //===----------------------------------------------------------------------===// -OperandRange -transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { - if (!point.isParent() && getOperation()->getNumOperands() == 1) +OperandRange transform::AlternativesOp::getEntrySuccessorOperands( + RegionSuccessor successor) { + if (!successor.isParent() && getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); @@ -107,15 +107,18 @@ transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { void transform::AlternativesOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { for (Region &alternative : llvm::drop_begin( - getAlternatives(), - point.isParent() ? 0 - : point.getRegionOrNull()->getRegionNumber() + 1)) { + getAlternatives(), point.isParent() + ? 0 + : point.getTerminatorPredecessorOrNull() + ->getParentRegion() + ->getRegionNumber() + + 1)) { regions.emplace_back(&alternative, !getOperands().empty() ? alternative.getArguments() : Block::BlockArgListType()); } if (!point.isParent()) - regions.emplace_back(getOperation()->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::AlternativesOp::getRegionInvocationBounds( @@ -1740,16 +1743,18 @@ void transform::ForeachOp::getSuccessorRegions( } // Branch back to the region or the parent. - assert(point == getBody() && "unexpected region index"); + assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBody() && + "unexpected region index"); regions.emplace_back(bodyRegion, bodyRegion->getArguments()); - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } OperandRange -transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { +transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) { // Each block argument handle is mapped to a subset (one op to be precise) // of the payload of the corresponding `targets` operand of ForeachOp. - assert(point == getBody() && "unexpected region index"); + assert(successor.getSuccessor() == &getBody() && "unexpected region index"); return getOperation()->getOperands(); } @@ -2948,8 +2953,8 @@ void transform::SequenceOp::getEffects( } OperandRange -transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBody() && "unexpected region index"); +transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBody() && "unexpected region index"); if (getOperation()->getNumOperands() > 0) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), @@ -2966,8 +2971,10 @@ void transform::SequenceOp::getSuccessorRegions( return; } - assert(point == getBody() && "unexpected region index"); - regions.emplace_back(getOperation()->getResults()); + assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBody() && + "unexpected region index"); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::SequenceOp::getRegionInvocationBounds( diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index c627158..f727118 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" @@ -112,7 +113,7 @@ static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, } OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands( - RegionBranchPoint point) { + RegionSuccessor successor) { // No operands will be forwarded to the region(s). return getOperands().slice(0, 0); } @@ -128,7 +129,7 @@ void transform::tune::AlternativesOp::getSuccessorRegions( for (Region &alternative : getAlternatives()) regions.emplace_back(&alternative, Block::BlockArgListType()); else - regions.emplace_back(getOperation()->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::tune::AlternativesOp::getRegionInvocationBounds( |
