diff options
Diffstat (limited to 'mlir/lib/Dialect')
26 files changed, 374 insertions, 134 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 4c4965e..df955fc 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() { if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) { return emitOpError( - "source element types much match (except for fp8) but have ") + "source element types must match (except for fp8/bf8) but have ") << sourceAType << " and " << sourceBType; } - if (!sourceAElemType.isInteger(4) && getK() != 16) { - return emitOpError("K dimension must be 16 for source element type ") - << sourceAElemType; + if (isSrcFloat) { + if (getClamp()) + return emitOpError("clamp flag is not supported for float types"); + if (getUnsignedA() || getUnsignedB()) + return emitOpError("unsigned flags are not supported for float types"); } return success(); } @@ -422,11 +424,11 @@ LogicalResult MFMAOp::verify() { Type sourceElem = sourceType, destElem = destType; uint32_t sourceLen = 1, destLen = 1; - if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) { + if (auto sourceVector = dyn_cast<VectorType>(sourceType)) { sourceLen = sourceVector.getNumElements(); sourceElem = sourceVector.getElementType(); } - if (auto destVector = llvm::dyn_cast<VectorType>(destType)) { + if (auto destVector = dyn_cast<VectorType>(destType)) { destLen = destVector.getNumElements(); destElem = destVector.getElementType(); } @@ -451,7 +453,7 @@ LogicalResult MFMAOp::verify() { return emitOpError("expected both non-small-float source operand types " "to match exactly"); } - // Normalize the wider integer types the compiler expects to i8 + // Normalize the wider integer types the compiler expects to i8. 
if (sourceElem.isInteger(32)) { sourceLen *= 4; sourceElem = b.getI8Type(); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index e0a53cd..0c35921 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2716,8 +2716,9 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor, return success(folded); } -OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert((point.isParent() || point == getRegion()) && "invalid region point"); +OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert((successor.isParent() || successor.getSuccessor() == &getRegion()) && + "invalid region point"); // The initial operands map to the loop arguments after the induction // variable or are forwarded to the results when the trip count is zero. @@ -2726,34 +2727,41 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { void AffineForOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { - assert((point.isParent() || point == getRegion()) && "expected loop region"); + assert((point.isParent() || + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getRegion()) && + "expected loop region"); // The loop may typically branch back to its body or to the parent operation. // If the predecessor is the parent op and the trip count is known to be at // least one, branch into the body using the iterator arguments. And in cases // we know the trip count is zero, it can only branch back to its parent. 
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this); - if (point.isParent() && tripCount.has_value()) { - if (tripCount.value() > 0) { - regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - return; - } - if (tripCount.value() == 0) { - regions.push_back(RegionSuccessor(getResults())); - return; + if (tripCount.has_value()) { + if (!point.isParent()) { + // From the loop body, if the trip count is one, we can only branch back + // to the parent. + if (tripCount == 1) { + regions.push_back(RegionSuccessor(getOperation(), getResults())); + return; + } + if (tripCount == 0) + return; + } else { + if (tripCount.value() > 0) { + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); + return; + } + if (tripCount.value() == 0) { + regions.push_back(RegionSuccessor(getOperation(), getResults())); + return; + } } } - // From the loop body, if the trip count is one, we can only branch back to - // the parent. - if (!point.isParent() && tripCount == 1) { - regions.push_back(RegionSuccessor(getResults())); - return; - } - // In all other cases, the loop may branch back to itself or the parent // operation. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } AffineBound AffineForOp::getLowerBound() { @@ -3142,7 +3150,7 @@ void AffineIfOp::getSuccessorRegions( RegionSuccessor(&getThenRegion(), getThenRegion().getArguments())); // If the "else" region is empty, branch bach into parent. if (getElseRegion().empty()) { - regions.push_back(getResults()); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } else { regions.push_back( RegionSuccessor(&getElseRegion(), getElseRegion().getArguments())); @@ -3152,7 +3160,7 @@ void AffineIfOp::getSuccessorRegions( // If the predecessor is the `else`/`then` region, then branching into parent // op is valid. 
- regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } LogicalResult AffineIfOp::verify() { diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index 50a0f3d..e08cc6f 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -978,12 +978,11 @@ static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp, LLVM_DEBUG( dbgs() << "\n[early-vect]+++++ affine.apply on vector operand\n"); return nullptr; - } else { - Value updatedOperand = state.valueScalarReplacement.lookupOrNull(operand); - if (!updatedOperand) - updatedOperand = operand; - updatedOperands.push_back(updatedOperand); } + Value updatedOperand = state.valueScalarReplacement.lookupOrNull(operand); + if (!updatedOperand) + updatedOperand = operand; + updatedOperands.push_back(updatedOperand); } auto newApplyOp = AffineApplyOp::create( diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index d925c19..a651710 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -216,8 +216,8 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) { for (auto condBranch : worklist) { auto loc = condBranch.getLoc(); Block *block = condBranch->getBlock(); - auto newTrueBranch = rewriter.splitBlock(block, block->end()); - auto newFalseBranch = rewriter.splitBlock(block, block->end()); + auto *newTrueBranch = rewriter.splitBlock(block, block->end()); + auto *newFalseBranch = rewriter.splitBlock(block, block->end()); insertJump(loc, newTrueBranch, condBranch.getTrueDest(), condBranch.getTrueDestOperands()); insertJump(loc, newFalseBranch, condBranch.getFalseDest(), @@ -382,7 +382,7 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const 
&operationToIndexMap, // Find or create a live range for `value`. auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator); LiveRange &valueLiveRange = it->second; - auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef); + auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef); // Add the interval [firstUseOrDef, lastUseInBlock) to the live range. unsigned startOpIdx = operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index dc7b07d..8e4a49d 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -36,8 +36,9 @@ void AsyncDialect::initialize() { constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes"; -OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBodyRegion() && "invalid region index"); +OperandRange ExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBodyRegion() && + "invalid region index"); return getBodyOperands(); } @@ -53,8 +54,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) { void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { // The `body` region branch back to the parent operation. 
- if (point == getBodyRegion()) { - regions.push_back(RegionSuccessor(getBodyResults())); + if (!point.isParent() && + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBodyRegion()) { + regions.push_back(RegionSuccessor(getOperation(), getBodyResults())); return; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index b593cca..36a759c 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -562,8 +562,11 @@ LogicalResult BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) { SmallVector<TypeRange> returnOperandTypes(llvm::map_range( op.getFunctionBody().getOps<RegionBranchTerminatorOpInterface>(), - [](RegionBranchTerminatorOpInterface op) { - return op.getSuccessorOperands(RegionBranchPoint::parent()).getTypes(); + [&](RegionBranchTerminatorOpInterface branchOp) { + return branchOp + .getSuccessorOperands(RegionSuccessor( + op.getOperation(), op.getOperation()->getResults())) + .getTypes(); })); if (!llvm::all_equal(returnOperandTypes)) return op->emitError( @@ -942,8 +945,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) { // about, but we would need to check how many successors there are and under // which condition they are taken, etc. 
- MutableOperandRange operands = - op.getMutableSuccessorOperands(RegionBranchPoint::parent()); + MutableOperandRange operands = op.getMutableSuccessorOperands( + RegionSuccessor(op.getOperation(), op.getOperation()->getResults())); SmallVector<Value> updatedOwnerships; auto result = deallocation_impl::insertDeallocOpForReturnLike( diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 4754f0b..0992ce14 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -845,7 +845,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); return; } @@ -854,7 +855,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. 
Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -871,7 +873,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index b5f8dda..6c6d8d2 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2399,7 +2399,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser, void WarpExecuteOnLane0Op::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index c551fba..1c21a2f 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { void AllocaScopeOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index a15bf89..3aa801b 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -66,7 +66,7 @@ struct ExpandShapeOpInterface 
ValueBoundsConstraintSet &cstr) const { auto expandOp = cast<memref::ExpandShapeOp>(op); assert(value == expandOp.getResult() && "invalid value"); - cstr.bound(value)[dim] == expandOp.getOutputShape()[dim]; + cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim]; } }; @@ -98,6 +98,27 @@ struct RankOpInterface } }; +struct CollapseShapeOpInterface + : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface, + memref::CollapseShapeOp> { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto collapseOp = cast<memref::CollapseShapeOp>(op); + assert(value == collapseOp.getResult() && "invalid value"); + + // Multiply the expressions for the dimensions in the reassociation group. + const ReassociationIndices &reassocIndices = + collapseOp.getReassociationIndices()[dim]; + AffineExpr productExpr = + cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]); + for (size_t i = 1; i < reassocIndices.size(); ++i) { + productExpr = + productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]); + } + cstr.bound(value)[dim] == productExpr; + } +}; + struct SubViewOpInterface : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface, SubViewOp> { @@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels( memref::AllocOpInterface<memref::AllocaOp>>(*ctx); memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); + memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>( + *ctx); memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>( *ctx); memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 291da1f..14152c5 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ 
b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" using namespace mlir; @@ -273,7 +274,9 @@ struct SubViewOpInterface Value one = arith::ConstantIndexOp::create(builder, loc, 1); auto metadataOp = ExtractStridedMetadataOp::create(builder, loc, subView.getSource()); - for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { + for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) { + // Reset insertion point to before the operation for each dimension + builder.setInsertionPoint(subView); Value offset = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp(builder, loc, @@ -290,6 +293,16 @@ struct SubViewOpInterface std::to_string(i) + " is out-of-bounds")); + // Only verify if size > 0 + Value sizeIsNonZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::sgt, size, zero); + + auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), + sizeIsNonZero, /*withElseRegion=*/true); + + // Populate the "then" region (for size > 0). + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); Value sizeMinusOneTimesStride = @@ -298,8 +311,20 @@ struct SubViewOpInterface arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); + + scf::YieldOp::create(builder, loc, lastPosInBounds); + + // Populate the "else" region (for size == 0). 
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + Value trueVal = + arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); + scf::YieldOp::create(builder, loc, trueVal); + + builder.setInsertionPointAfter(ifOp); + Value finalCondition = ifOp.getResult(0); + cf::AssertOp::create( - builder, loc, lastPosInBounds, + builder, loc, finalCondition, generateErrorMessage(op, "subview runs out-of-bounds along dimension " + std::to_string(i))); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 744a595..2946b53 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -111,10 +111,8 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, return nullptr; } -/// Helper function to compute the difference between two values. This is used -/// by the loop implementations to compute the trip count. -static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub, - bool isSigned) { +std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub, + bool isSigned) { llvm::APSInt diff; auto addOp = ub.getDefiningOp<arith::AddIOp>(); if (!addOp) @@ -399,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions( } // Otherwise, the region branches back to the parent operation. 
- regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } //===----------------------------------------------------------------------===// @@ -407,10 +405,11 @@ void ExecuteRegionOp::getSuccessorRegions( //===----------------------------------------------------------------------===// MutableOperandRange -ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { - assert((point.isParent() || point == getParentOp().getAfter()) && - "condition op can only exit the loop or branch to the after" - "region"); +ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) { + assert( + (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) && + "condition op can only exit the loop or branch to the after" + "region"); // Pass all operands except the condition to the successor region. return getArgsMutable(); } @@ -428,7 +427,7 @@ void ConditionOp::getSuccessorRegions( regions.emplace_back(&whileOp.getAfter(), whileOp.getAfter().getArguments()); if (!boolAttr || !boolAttr.getValue()) - regions.emplace_back(whileOp.getResults()); + regions.emplace_back(whileOp.getOperation(), whileOp.getResults()); } //===----------------------------------------------------------------------===// @@ -751,7 +750,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) { return dyn_cast_or_null<ForOp>(containingOp); } -OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) { return getInitArgs(); } @@ -761,7 +760,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, // back into the operation itself. It is possible for loop not to enter the // body. 
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } @@ -2055,9 +2054,10 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point, // parallel by multiple threads. We should not expect to branch back into // the forall body after the region's execution is complete. if (point.isParent()) - regions.push_back(RegionSuccessor(&getRegion())); + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); else - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); } //===----------------------------------------------------------------------===// @@ -2335,9 +2335,10 @@ void IfOp::print(OpAsmPrinter &p) { void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { - // The `then` and the `else` region branch back to the parent operation. + // The `then` and the `else` region branch back to the parent operation or one + // of the recursive parent operations (early exit case). if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -2346,7 +2347,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty. 
Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -2363,7 +2365,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); } } @@ -3387,7 +3389,8 @@ void ParallelOp::getSuccessorRegions( // back into the operation itself. It is possible for loop not to enter the // body. regions.push_back(RegionSuccessor(&getRegion())); - regions.push_back(RegionSuccessor()); + regions.push_back(RegionSuccessor( + getOperation(), ResultRange{getResults().end(), getResults().end()})); } //===----------------------------------------------------------------------===// @@ -3433,7 +3436,7 @@ LogicalResult ReduceOp::verifyRegions() { } MutableOperandRange -ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { +ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) { // No operands are forwarded to the next iteration. 
return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0); } @@ -3516,8 +3519,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() { return getBeforeArguments(); } -OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBefore() && +OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBefore() && "WhileOp is expected to branch only to the first region"); return getInits(); } @@ -3530,15 +3533,18 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point, return; } - assert(llvm::is_contained({&getAfter(), &getBefore()}, point) && + assert(llvm::is_contained( + {&getAfter(), &getBefore()}, + point.getTerminatorPredecessorOrNull()->getParentRegion()) && "there are only two regions in a WhileOp"); // The body region always branches back to the condition region. - if (point == getAfter()) { + if (point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getAfter()) { regions.emplace_back(&getBefore(), getBefore().getArguments()); return; } - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); regions.emplace_back(&getAfter(), getAfter().getArguments()); } @@ -4447,7 +4453,7 @@ void IndexSwitchOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) { // All regions branch back to the parent op. 
if (!point.isParent()) { - successors.emplace_back(getResults()); + successors.emplace_back(getOperation(), getResults()); return; } diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index ae52af5..ddcbda8 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -23,7 +23,6 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir -using namespace llvm; using namespace mlir; using scf::ForOp; using scf::WhileOp; diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp index a2f03f1..00bef70 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp @@ -21,7 +21,6 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir -using namespace llvm; using namespace mlir; using scf::LoopNest; diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 10eae89..888dd44 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -291,47 +291,61 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, return arith::DivUIOp::create(builder, loc, sum, divisor); } -/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with -/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap -/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each -/// unrolled iteration using annotateFn. 
-static void generateUnrolledLoop( - Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, +void mlir::generateUnrolledLoop( + Block *loopBodyBlock, Value iv, uint64_t unrollFactor, function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn, - ValueRange iterArgs, ValueRange yieldedValues) { + ValueRange iterArgs, ValueRange yieldedValues, + IRMapping *clonedToSrcOpsMap) { + + // Check if the op was cloned from another source op, and return it if found + // (or the same op if not found) + auto findOriginalSrcOp = + [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * { + Operation *srcOp = op; + // If the source op derives from another op: traverse the chain to find the + // original source op + while (srcOp && clonedToSrcOpsMap.contains(srcOp)) + srcOp = clonedToSrcOpsMap.lookup(srcOp); + return srcOp; + }; + // Builder to insert unrolled bodies just before the terminator of the body of - // 'forOp'. + // the loop. auto builder = OpBuilder::atBlockTerminator(loopBodyBlock); - constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {}; + static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {}; if (!annotateFn) - annotateFn = defaultAnnotateFn; + annotateFn = noopAnnotateFn; // Keep a pointer to the last non-terminator operation in the original block // so that we know what to clone (since we are doing this in-place). Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2); - // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies). + // Unroll the contents of the loop body (append unrollFactor - 1 additional + // copies). SmallVector<Value, 4> lastYielded(yieldedValues); for (unsigned i = 1; i < unrollFactor; i++) { - IRMapping operandMap; - // Prepare operand map. 
+ IRMapping operandMap; operandMap.map(iterArgs, lastYielded); // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forOpIV.use_empty()) { - Value ivUnroll = ivRemapFn(i, forOpIV, builder); - operandMap.map(forOpIV, ivUnroll); + if (!iv.use_empty()) { + Value ivUnroll = ivRemapFn(i, iv, builder); + operandMap.map(iv, ivUnroll); } // Clone the original body of 'forOp'. for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) { - Operation *clonedOp = builder.clone(*it, operandMap); + Operation *srcOp = &(*it); + Operation *clonedOp = builder.clone(*srcOp, operandMap); annotateFn(i, clonedOp, builder); + if (clonedToSrcOpsMap) + clonedToSrcOpsMap->map(clonedOp, + findOriginalSrcOp(srcOp, *clonedToSrcOpsMap)); } // Update yielded values. @@ -1544,3 +1558,100 @@ bool mlir::isPerfectlyNestedForLoops( } return true; } + +llvm::SmallVector<int64_t> +mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) { + std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds(); + std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds(); + std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps(); + if (!loBnds || !upBnds || !steps) + return {}; + llvm::SmallVector<int64_t> tripCounts; + for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) { + std::optional<llvm::APInt> numIter = constantTripCount( + lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb); + if (!numIter) + return {}; + tripCounts.push_back(numIter->getSExtValue()); + } + return tripCounts; +} + +FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors( + scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors, + RewriterBase &rewriter, + function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn, + IRMapping *clonedToSrcOpsMap) { + const unsigned numLoops = op.getNumLoops(); + assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) && + "Expected 
positive unroll factors"); + assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) && + "Expected non-empty unroll factors of size <= to the number of loops"); + + // Bail out if no valid unroll factors were provided + if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; })) + return rewriter.notifyMatchFailure( + op, "Unrolling not applied if all factors are 1"); + + // Return if the loop body is empty. + if (llvm::hasSingleElement(op.getBody()->getOperations())) + return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body"); + + // If the provided unroll factors do not cover all the loop dims, they are + // applied to the inner loop dimensions. + const unsigned firstLoopDimIdx = numLoops - unrollFactors.size(); + + // Make sure that the unroll factors divide the iteration space evenly + // TODO: Support unrolling loops with dynamic iteration spaces. + const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op); + if (tripCounts.empty()) + return rewriter.notifyMatchFailure( + op, "Failed to compute constant trip counts for the loop. 
Note that " + "dynamic loop sizes are not supported."); + + for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) { + const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx]; + if (tripCounts[dimIdx] % unrollFactor) + return rewriter.notifyMatchFailure( + op, "Unroll factors don't divide the iteration space evenly"); + } + + std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps(); + if (!maybeFoldSteps) + return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps"); + llvm::SmallVector<size_t> steps{}; + for (auto step : *maybeFoldSteps) + steps.push_back(static_cast<size_t>(*getConstantIntValue(step))); + + for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) { + const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx]; + if (unrollFactor == 1) + continue; + const size_t origStep = steps[dimIdx]; + const int64_t newStep = origStep * unrollFactor; + IRMapping clonedToSrcOpsMap; + + ValueRange iterArgs = ValueRange(op.getRegionIterArgs()); + auto yieldedValues = op.getBody()->getTerminator()->getOperands(); + + generateUnrolledLoop( + op.getBody(), op.getInductionVars()[dimIdx], unrollFactor, + [&](unsigned i, Value iv, OpBuilder b) { + // iv' = iv + step * i; + const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i); + const auto map = + b.getDimIdentityMap().dropResult(0).insertResult(expr, 0); + return affine::AffineApplyOp::create(b, iv.getLoc(), map, + ValueRange{iv}); + }, + /*annotateFn*/ annotateFn, iterArgs, yieldedValues, &clonedToSrcOpsMap); + + // Update loop step + auto prevInsertPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + op.getStepMutable()[dimIdx].assign( + arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep)); + rewriter.restoreInsertionPoint(prevInsertPoint); + } + return op; +} diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index fe50865..0c8114d 100644 --- 
a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1276,12 +1276,19 @@ LogicalResult spirv::GlobalVariableOp::verify() { Operation *initOp = SymbolTable::lookupNearestSymbolFrom( (*this)->getParentOp(), init.getAttr()); // TODO: Currently only variable initialization with specialization - // constants and other variables is supported. They could be normal - // constants in the module scope as well. - if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp, - spirv::SpecConstantCompositeOp>(initOp)) { + // constants is supported. There could be normal constants in the module + // scope as well. + // + // In the current setup we also cannot initialize one global variable with + // another. The problem is that if we try to initialize pointer of type X + // with another pointer type, the validator fails because it expects the + // variable to be initialized to be type X, not pointer to X. Now + // `spirv.GlobalVariable` only allows pointer type, so in the current design + // we cannot initialize one `spirv.GlobalVariable` with another. + if (!initOp || + !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) { return emitOpError("initializer must be result of a " - "spirv.SpecConstant or spirv.GlobalVariable or " + "spirv.SpecConstant or " "spirv.SpecConstantCompositeOp op"); } } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 5ba8289..f0f22e5 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -346,7 +346,7 @@ void AssumingOp::getSuccessorRegions( // parent, so return the correct RegionSuccessor purely based on the index // being None or 0. 
if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 1a9d9e1..3962e3e 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2597,7 +2597,7 @@ std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } -OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) { return getInitArgs(); } @@ -2607,7 +2607,7 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point, // or back into the operation itself. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); // It is possible for loop not to enter the body. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp index 73e0f3d..f53d272 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp @@ -159,14 +159,22 @@ IterationGraphSorter::IterationGraphSorter( loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)), strategy(strategy) { // One map per tensor. - assert(loop2InsLvl.size() == ins.size()); + assert(this->loop2InsLvl.size() == this->ins.size()); // All the affine maps have the same number of dimensions (loops). 
assert(llvm::all_equal(llvm::map_range( - loop2InsLvl, [](AffineMap m) { return m.getNumDims(); }))); + this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); }))); // The number of results of the map should match the rank of the tensor. - assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) { + assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) { auto [m, v] = mvPair; - return m.getNumResults() == cast<ShapedType>(v.getType()).getRank(); + + // For ranked types the rank must match. + // Simply return true for UnrankedTensorType + if (auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) { + return !shapedType.hasRank() || + (m.getNumResults() == shapedType.getRank()); + } + // Non-shaped (scalar) types behave like rank-0. + return m.getNumResults() == 0; })); itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false)); diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp index c031118..753cb95 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" @@ -158,7 +159,11 @@ struct ExtractSliceOpInterface // 0 <= offset + (size - 1) * stride < dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); - for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { + + for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) { + // Reset insertion point to before the operation for each dimension + builder.setInsertionPoint(extractSliceOp); + Value offset = 
getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp( @@ -176,6 +181,16 @@ struct ExtractSliceOpInterface std::to_string(i) + " is out-of-bounds")); + // Only verify if size > 0 + Value sizeIsNonZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::sgt, size, zero); + + auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), + sizeIsNonZero, /*withElseRegion=*/true); + + // Populate the "then" region (for size > 0). + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); Value sizeMinusOneTimesStride = @@ -184,8 +199,19 @@ struct ExtractSliceOpInterface arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); + scf::YieldOp::create(builder, loc, lastPosInBounds); + + // Populate the "else" region (for size == 0). + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + Value trueVal = + arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); + scf::YieldOp::create(builder, loc, trueVal); + + builder.setInsertionPointAfter(ifOp); + Value finalCondition = ifOp.getResult(0); + cf::AssertOp::create( - builder, loc, lastPosInBounds, + builder, loc, finalCondition, generateErrorMessage( op, "extract_slice runs out-of-bounds along dimension " + std::to_string(i))); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index a85ff10a..293c6af 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -38,7 +38,7 @@ using namespace mlir::tosa; //===----------------------------------------------------------------------===// // Check that the zero point of the tensor and padding operations are aligned. 
-bool checkMatchingPadConstAndZp(Value padConst, Value zp) { +static bool checkMatchingPadConstAndZp(Value padConst, Value zp) { // Check that padConst is a constant value and a scalar tensor DenseElementsAttr padConstAttr; if (!matchPattern(padConst, m_Constant(&padConstAttr)) || @@ -889,8 +889,9 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, //===----------------------------------------------------------------------===// template <typename IntFolder, typename FloatFolder> -DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, - RankedTensorType returnTy) { +static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, + DenseElementsAttr rhs, + RankedTensorType returnTy) { if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType(); auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 365afab..062606e 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -96,9 +96,9 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, // AlternativesOp //===----------------------------------------------------------------------===// -OperandRange -transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { - if (!point.isParent() && getOperation()->getNumOperands() == 1) +OperandRange transform::AlternativesOp::getEntrySuccessorOperands( + RegionSuccessor successor) { + if (!successor.isParent() && getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); @@ -107,15 +107,18 @@ transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { void transform::AlternativesOp::getSuccessorRegions( RegionBranchPoint point, 
SmallVectorImpl<RegionSuccessor> &regions) { for (Region &alternative : llvm::drop_begin( - getAlternatives(), - point.isParent() ? 0 - : point.getRegionOrNull()->getRegionNumber() + 1)) { + getAlternatives(), point.isParent() + ? 0 + : point.getTerminatorPredecessorOrNull() + ->getParentRegion() + ->getRegionNumber() + + 1)) { regions.emplace_back(&alternative, !getOperands().empty() ? alternative.getArguments() : Block::BlockArgListType()); } if (!point.isParent()) - regions.emplace_back(getOperation()->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::AlternativesOp::getRegionInvocationBounds( @@ -1740,16 +1743,18 @@ void transform::ForeachOp::getSuccessorRegions( } // Branch back to the region or the parent. - assert(point == getBody() && "unexpected region index"); + assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBody() && + "unexpected region index"); regions.emplace_back(bodyRegion, bodyRegion->getArguments()); - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } OperandRange -transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { +transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) { // Each block argument handle is mapped to a subset (one op to be precise) // of the payload of the corresponding `targets` operand of ForeachOp. 
- assert(point == getBody() && "unexpected region index"); + assert(successor.getSuccessor() == &getBody() && "unexpected region index"); return getOperation()->getOperands(); } @@ -2948,8 +2953,8 @@ void transform::SequenceOp::getEffects( } OperandRange -transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBody() && "unexpected region index"); +transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBody() && "unexpected region index"); if (getOperation()->getNumOperands() > 0) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), @@ -2966,8 +2971,10 @@ void transform::SequenceOp::getSuccessorRegions( return; } - assert(point == getBody() && "unexpected region index"); - regions.emplace_back(getOperation()->getResults()); + assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBody() && + "unexpected region index"); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::SequenceOp::getRegionInvocationBounds( diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index c627158..f727118 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" @@ -112,7 +113,7 @@ static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, } OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands( - RegionBranchPoint point) { + RegionSuccessor successor) { // No operands will be 
forwarded to the region(s). return getOperands().slice(0, 0); } @@ -128,7 +129,7 @@ void transform::tune::AlternativesOp::getSuccessorRegions( for (Region &alternative : getAlternatives()) regions.emplace_back(&alternative, Block::BlockArgListType()); else - regions.emplace_back(getOperation()->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::tune::AlternativesOp::getRegionInvocationBounds( diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 24e9095..f9aa28d5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -113,9 +113,12 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, if (layout.size() != shape.size()) return std::nullopt; auto ratio = computeShapeRatio(shape, layout); - if (!ratio.has_value()) + if (ratio.has_value()) { + newShape = ratio.value(); + } else if (!rr || !computeShapeRatio(layout, shape).has_value()) { return std::nullopt; - newShape = ratio.value(); + } + // Round-robin case: continue with original newShape } if (data.size()) { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 2c37140..ec5feb8 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -344,6 +344,13 @@ void XeGPUBlockingPass::runOnOperation() { xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter); + // Remove leading unit dimensions from vector ops and then + // do the unrolling. 
+ { + RewritePatternSet patterns(ctx); + vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); + (void)applyPatternsGreedily(op, std::move(patterns)); + } xegpu::UnrollOptions options; options.setFilterConstraint( [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); }); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index b4605cd..a38993e 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -147,7 +147,7 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { } if (auto arg = dyn_cast<BlockArgument>(value)) { - auto parentOp = arg.getOwner()->getParentOp(); + auto *parentOp = arg.getOwner()->getParentOp(); if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) { OpOperand *tiedInit = loop.getTiedLoopInit(arg); if (tiedInit) |
