diff options
Diffstat (limited to 'mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp')
| -rw-r--r-- | mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index a15bf89..3aa801b 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -66,7 +66,7 @@ struct ExpandShapeOpInterface ValueBoundsConstraintSet &cstr) const { auto expandOp = cast<memref::ExpandShapeOp>(op); assert(value == expandOp.getResult() && "invalid value"); - cstr.bound(value)[dim] == expandOp.getOutputShape()[dim]; + cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim]; } }; @@ -98,6 +98,27 @@ struct RankOpInterface } }; +struct CollapseShapeOpInterface + : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface, + memref::CollapseShapeOp> { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto collapseOp = cast<memref::CollapseShapeOp>(op); + assert(value == collapseOp.getResult() && "invalid value"); + + // Multiply the expressions for the dimensions in the reassociation group. + const ReassociationIndices &reassocIndices = + collapseOp.getReassociationIndices()[dim]; + AffineExpr productExpr = + cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]); + for (size_t i = 1; i < reassocIndices.size(); ++i) { + productExpr = + productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]); + } + cstr.bound(value)[dim] == productExpr; + } +}; + struct SubViewOpInterface : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface, SubViewOp> { @@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels( memref::AllocOpInterface<memref::AllocaOp>>(*ctx); memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); + memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>( + *ctx); memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>( *ctx); memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); |
