diff options
| -rw-r--r-- | mlir/include/mlir/IR/Value.h | 10 | ||||
| -rw-r--r-- | mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 4 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Shard/Transforms/Partition.cpp | 16 |
3 files changed, 19 insertions, 11 deletions
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 4d6d89f..af58778 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -433,9 +433,19 @@ inline unsigned OpResultImpl::getResultNumber() const { template <typename Ty> struct TypedValue : Value { using Value::Value; + using ValueType = Ty; static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); } + /// TypedValue<B> can implicitly convert to TypedValue<A> if B is assignable + /// to A. + template <typename ToTy, + typename = typename std::enable_if<std::is_assignable< + typename ToTy::ValueType &, Ty>::value>::type> + operator ToTy() const { + return llvm::cast<ToTy>(*this); + } + /// Return the known Type Ty getType() const { return llvm::cast<Ty>(Value::getType()); } void setType(Ty ty) { Value::setType(ty); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index c06a48e..c551fba 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1751,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) { } TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() { - return cast<TypedValue<PtrLikeTypeInterface>>(getSource()); + return getSource(); } TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() { - return cast<TypedValue<PtrLikeTypeInterface>>(getDest()); + return getDest(); } bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt, diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 5dc61a2..335ca1a 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue<ShapedType> sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) { - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( + TypedValue<ShapedType> targetShard = AllSliceOp::create(builder, sourceShard, grid, ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis) - .getResult()); + .getResult(); Sharding targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis); return {targetShard, targetSharding}; @@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding( APInt(64, splitTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, grid, targetSharding); - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( - tensor::CastOp::create(builder, targetShape, allGatherResult) - .getResult()); + TypedValue<ShapedType> targetShard = + tensor::CastOp::create(builder, targetShape, allGatherResult).getResult(); return {targetShard, targetSharding}; } @@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, grid, targetSharding); - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( - tensor::CastOp::create(builder, targetShape, allToAllResult).getResult()); + TypedValue<ShapedType> targetShard = + tensor::CastOp::create(builder, targetShape, allToAllResult).getResult(); return {targetShard, targetSharding}; } @@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source, auto targetSharding = target.getSharding(); ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding, - cast<TypedValue<ShapedType>>(source.getSrc()), - sourceShardValue); + source.getSrc(), sourceShardValue); } TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source, |
