diff options
author | Ian Wood <ianwood2024@u.northwestern.edu> | 2025-07-24 08:07:51 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-24 16:07:51 +0100 |
commit | 3ebe5d661f7829b2ffe1b422ec7d00d3213c9730 (patch) | |
tree | 542d2f32145cccb12606c314459517affd32aceb | |
parent | 8f8b436c2b914a8abcee12b8a3bf45aec9fa627e (diff) | |
download | llvm-3ebe5d661f7829b2ffe1b422ec7d00d3213c9730.zip llvm-3ebe5d661f7829b2ffe1b422ec7d00d3213c9730.tar.gz llvm-3ebe5d661f7829b2ffe1b422ec7d00d3213c9730.tar.bz2 |
[mlir][linalg] Drop unit dims on IndexingMapOpInterface (#150280)
Generalizes `dropUnitDims` to operate on any op implementing the
`IndexingMapOpInterface`. Operation specific creation is handled by
passing a builder that will construct the new operation based on the
dropped dimensions.
---------
Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>
Co-authored-by: Kunwar Grover <groverkss@gmail.com>
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 12 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 121 |
2 files changed, 88 insertions, 45 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 38e5364..e625eef 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -537,10 +537,20 @@ struct ControlDropUnitDims { return SmallVector<unsigned>{}; }; }; + struct DropUnitDimsResult { - linalg::GenericOp resultOp; + IndexingMapOpInterface resultOp; SmallVector<Value> replacements; }; +using DroppedUnitDimsBuilder = std::function<IndexingMapOpInterface( + Location loc, OpBuilder &, IndexingMapOpInterface, + ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps, + const llvm::SmallDenseSet<unsigned> &droppedDims)>; + +FailureOr<DropUnitDimsResult> +dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op, + const DroppedUnitDimsBuilder &droppedUnitDimsBuilder, + const ControlDropUnitDims &options); FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, const ControlDropUnitDims &options); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index e0062d1..6c59cd6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -331,14 +331,14 @@ struct UnitExtentReplacementInfo { SmallVector<int64_t> targetShape; }; static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata( - MLIRContext *context, GenericOp genericOp, OpOperand *opOperand, + MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand, llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap, ArrayRef<AffineExpr> dimReplacements) { UnitExtentReplacementInfo info; ReassociationIndices reassociationGroup; SmallVector<AffineExpr> newIndexExprs; - AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); - ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand); + AffineMap indexingMap = op.getMatchingIndexingMap(opOperand); + SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand); ArrayRef<AffineExpr> exprs = indexingMap.getResults(); auto isUnitDim = [&](unsigned dim) { @@ -380,9 +380,16 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata( } FailureOr<DropUnitDimsResult> -linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, +linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op, + const DroppedUnitDimsBuilder &droppedUnitDimsBuilder, const ControlDropUnitDims &options) { - SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); + auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation()); + if (!dpsOp) { + return rewriter.notifyMatchFailure( + op, "op should implement DestinationStyleOpInterface"); + } + + SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray(); if (indexingMaps.empty()) return failure(); @@ -392,19 +399,19 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext())); if (!invertedMap) { - return rewriter.notifyMatchFailure(genericOp, + return rewriter.notifyMatchFailure(op, "invalid indexing maps for operation"); } SmallVector<int64_t> allShapesSizes; - for (OpOperand &opOperand : genericOp->getOpOperands()) - llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand)); + for (OpOperand &opOperand : op->getOpOperands()) + llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand)); // 1a. Get the allowed list of dimensions to drop from the `options`. - SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp); + SmallVector<unsigned> allowedUnitDims = options.controlFn(op); if (allowedUnitDims.empty()) { return rewriter.notifyMatchFailure( - genericOp, "control function returns no allowed unit dims to prune"); + op, "control function returns no allowed unit dims to prune"); } llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(), allowedUnitDims.end()); @@ -417,19 +424,16 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, } } - // 2. Compute the iterator types of the modified op by dropping the one-trip + // 2. Compute the new loops of the modified op by dropping the one-trip // count loops. - SmallVector<utils::IteratorType> newIteratorTypes; llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap; SmallVector<AffineExpr> dimReplacements; unsigned newDims = 0; - for (auto [index, attr] : - llvm::enumerate(genericOp.getIteratorTypesArray())) { + for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) { if (unitDims.count(index)) { dimReplacements.push_back( getAffineConstantExpr(0, rewriter.getContext())); } else { - newIteratorTypes.push_back(attr); oldDimToNewDimMap[index] = newDims; dimReplacements.push_back( getAffineDimExpr(newDims, rewriter.getContext())); @@ -462,9 +466,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, } return false; }; - for (OpOperand &opOperand : genericOp->getOpOperands()) { - auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand); - ArrayRef<int64_t> shape = genericOp.getShape(&opOperand); + for (OpOperand &opOperand : op->getOpOperands()) { + auto indexingMap = op.getMatchingIndexingMap(&opOperand); + SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand); if (!hasCollapsibleType(opOperand)) { AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols( dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0); @@ -474,9 +478,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, reassociations.push_back({}); continue; } - auto replacementInfo = dropUnitExtentFromOperandMetadata( - rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap, - dimReplacements); + auto replacementInfo = + dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand, + oldDimToNewDimMap, dimReplacements); reassociations.push_back(replacementInfo.reassociation); newIndexingMaps.push_back(replacementInfo.indexMap); targetShapes.push_back(replacementInfo.targetShape); @@ -491,13 +495,13 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, concatAffineMaps(newIndexingMaps, rewriter.getContext()))) return failure(); - Location loc = genericOp.getLoc(); + Location loc = op.getLoc(); // 4. For each of the operands, collapse the operand to convert // from original shape to shape in the modified operation if needed, // either through use of reshapes or rank-reducing slices as // specified in `options`. SmallVector<Value> newOperands; - for (OpOperand &opOperand : genericOp->getOpOperands()) { + for (OpOperand &opOperand : op->getOpOperands()) { int64_t idx = opOperand.getOperandNumber(); if (!collapsed[idx]) { newOperands.push_back(opOperand.get()); @@ -508,31 +512,15 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, options.rankReductionStrategy)); } - // 5. Create the `linalg.generic` operation with the new operands, - // indexing maps, iterator types and result types. - ArrayRef<Value> newInputs = - ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs()); - ArrayRef<Value> newOutputs = - ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits()); - SmallVector<Type> resultTypes; - resultTypes.reserve(genericOp.getNumResults()); - for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults())) - resultTypes.push_back(newOutputs[i].getType()); - GenericOp replacementOp = - rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs, - newIndexingMaps, newIteratorTypes); - rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(), - replacementOp.getRegion().begin()); - // 5a. Replace `linalg.index` operations that refer to the dropped unit - // dimensions. - replaceUnitDimIndexOps(replacementOp, unitDims, rewriter); + IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder( + loc, rewriter, op, newOperands, newIndexingMaps, unitDims); // 6. If any result type changes, insert a reshape/slice to convert from the // original type to the new type. SmallVector<Value> resultReplacements; - for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) { - unsigned opOperandIndex = index + replacementOp.getNumDpsInputs(); - Value origDest = genericOp.getDpsInitOperand(index)->get(); + for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) { + unsigned opOperandIndex = index + dpsOp.getNumDpsInputs(); + Value origDest = dpsOp.getDpsInitOperand(index)->get(); if (!collapsed[opOperandIndex]) { resultReplacements.push_back(result); continue; @@ -546,6 +534,51 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, return DropUnitDimsResult{replacementOp, resultReplacements}; } +FailureOr<DropUnitDimsResult> +linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, + const ControlDropUnitDims &options) { + + DroppedUnitDimsBuilder build = + [](Location loc, OpBuilder &b, IndexingMapOpInterface op, + ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps, + const llvm::SmallDenseSet<unsigned> &droppedDims) + -> IndexingMapOpInterface { + auto genericOp = cast<GenericOp>(op); + // Compute the iterator types of the modified op by dropping the one-trip + // count loops. + SmallVector<utils::IteratorType> newIteratorTypes; + for (auto [index, attr] : + llvm::enumerate(genericOp.getIteratorTypesArray())) { + if (!droppedDims.count(index)) + newIteratorTypes.push_back(attr); + } + + // Create the `linalg.generic` operation with the new operands, + // indexing maps, iterator types and result types. + ArrayRef<Value> newInputs = + ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs()); + ArrayRef<Value> newOutputs = + ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits()); + SmallVector<Type> resultTypes; + resultTypes.reserve(genericOp.getNumResults()); + for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults())) + resultTypes.push_back(newOutputs[i].getType()); + GenericOp replacementOp = + b.create<GenericOp>(loc, resultTypes, newInputs, newOutputs, + newIndexingMaps, newIteratorTypes); + b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(), + replacementOp.getRegion().begin()); + // 5a. Replace `linalg.index` operations that refer to the dropped unit + // dimensions. + IRRewriter rewriter(b); + replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter); + + return replacementOp; + }; + + return dropUnitDims(rewriter, genericOp, build, options); +} + namespace { struct DropUnitDims : public OpRewritePattern<GenericOp> { DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {}, |