diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 30 | 
1 files changed, 20 insertions, 10 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 3487e81..7e70ba2 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -245,11 +245,14 @@ raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {  } // namespace -static bool isUniformShape(Value *V) { +static bool isShapePreserving(Value *V) {    Instruction *I = dyn_cast<Instruction>(V);    if (!I)      return true; +  if (isa<SelectInst>(I)) +    return true; +    if (I->isBinaryOp())      return true; @@ -300,6 +303,16 @@ static bool isUniformShape(Value *V) {    }  } +/// Return an iterator over the operands of \p I that should share shape +/// information with \p I. +static iterator_range<Use *> getShapedOperandsForInst(Instruction *I) { +  assert(isShapePreserving(I) && +         "Can't retrieve shaped operands for an instruction that does not " +         "preserve shape information"); +  auto Ops = I->operands(); +  return isa<SelectInst>(I) ? drop_begin(Ops) : Ops; +} +  /// Return the ShapeInfo for the result of \p I, it it can be determined.  static std::optional<ShapeInfo>  computeShapeInfoForInst(Instruction *I, @@ -329,9 +342,8 @@ computeShapeInfoForInst(Instruction *I,        return OpShape->second;    } -  if (isUniformShape(I) || isa<SelectInst>(I)) { -    auto Ops = I->operands(); -    auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops; +  if (isShapePreserving(I)) { +    auto ShapedOps = getShapedOperandsForInst(I);      // Find the first operand that has a known shape and use that.      for (auto &Op : ShapedOps) {        auto OpShape = ShapeMap.find(Op.get()); @@ -710,10 +722,9 @@ public:        case Intrinsic::matrix_column_major_store:          return true;        default: -        return isUniformShape(II); +        break;        } -    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) || -           isa<SelectInst>(V); +    return isShapePreserving(V) || isa<StoreInst>(V) || isa<LoadInst>(V);    }    /// Propagate the shape information of instructions to their users. @@ -800,9 +811,8 @@ public:        } else if (isa<StoreInst>(V)) {          // Nothing to do.  We forward-propagated to this so we would just          // backward propagate to an instruction with an already known shape. -      } else if (isUniformShape(V) || isa<SelectInst>(V)) { -        auto Ops = cast<Instruction>(V)->operands(); -        auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops; +      } else if (isShapePreserving(V)) { +        auto ShapedOps = getShapedOperandsForInst(cast<Instruction>(V));          // Propagate to all operands.          ShapeInfo Shape = ShapeMap[V];          for (Use &U : ShapedOps) {  | 
