aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/VE/VVPISelLowering.cpp
blob: f1e2d7f717016bd5a524a4664125b9bd05f59fee (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
//===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the lowering and legalization of vector instructions to
// VVP_*layer SDNodes.
//
//===----------------------------------------------------------------------===//

#include "VECustomDAG.h"
#include "VEISelLowering.h"

using namespace llvm;

#define DEBUG_TYPE "ve-lower"

SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
                                              SelectionDAG &DAG) const {
  VECustomDAG CDAG(DAG, Op);
  SDValue AVL =
      CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32);
  SDValue A = Op->getOperand(0);
  SDValue B = Op->getOperand(1);
  SDValue LoA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Lo, AVL);
  SDValue HiA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Hi, AVL);
  SDValue LoB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Lo, AVL);
  SDValue HiB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Hi, AVL);
  unsigned Opc = Op.getOpcode();
  auto LoRes = CDAG.getNode(Opc, MVT::v256i1, {LoA, LoB});
  auto HiRes = CDAG.getNode(Opc, MVT::v256i1, {HiA, HiB});
  return CDAG.getPack(MVT::v512i1, LoRes, HiRes, AVL);
}

SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
  // Can we represent this as a VVP node.
  const unsigned Opcode = Op->getOpcode();
  auto VVPOpcodeOpt = getVVPOpcode(Opcode);
  if (!VVPOpcodeOpt)
    return SDValue();
  unsigned VVPOpcode = *VVPOpcodeOpt;
  const bool FromVP = ISD::isVPOpcode(Opcode);

  // The representative and legalized vector type of this operation.
  VECustomDAG CDAG(DAG, Op);
  // Dispatch to complex lowering functions.
  switch (VVPOpcode) {
  case VEISD::VVP_LOAD:
  case VEISD::VVP_STORE:
    return lowerVVP_LOAD_STORE(Op, CDAG);
  case VEISD::VVP_GATHER:
  case VEISD::VVP_SCATTER:
    return lowerVVP_GATHER_SCATTER(Op, CDAG);
  }

  EVT OpVecVT = *getIdiomaticVectorType(Op.getNode());
  EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT);
  auto Packing = getTypePacking(LegalVecVT.getSimpleVT());

  SDValue AVL;
  SDValue Mask;

  if (FromVP) {
    // All upstream VP SDNodes always have a mask and avl.
    auto MaskIdx = ISD::getVPMaskIdx(Opcode);
    auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode);
    if (MaskIdx)
      Mask = Op->getOperand(*MaskIdx);
    if (AVLIdx)
      AVL = Op->getOperand(*AVLIdx);
  }

  // Materialize default mask and avl.
  if (!AVL)
    AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32);
  if (!Mask)
    Mask = CDAG.getConstantMask(Packing, true);

  assert(LegalVecVT.isSimple());
  if (isVVPUnaryOp(VVPOpcode))
    return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL});
  if (isVVPBinaryOp(VVPOpcode))
    return CDAG.getNode(VVPOpcode, LegalVecVT,
                        {Op->getOperand(0), Op->getOperand(1), Mask, AVL});
  if (isVVPReductionOp(VVPOpcode)) {
    auto SrcHasStart = hasReductionStartParam(Op->getOpcode());
    SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue();
    SDValue VectorV = Op->getOperand(SrcHasStart ? 1 : 0);
    return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV,
                                       VectorV, Mask, AVL, Op->getFlags());
  }

  switch (VVPOpcode) {
  default:
    llvm_unreachable("lowerToVVP called for unexpected SDNode.");
  case VEISD::VVP_FFMA: {
    // VE has a swizzled operand order in FMA (compared to LLVM IR and
    // SDNodes).
    auto X = Op->getOperand(2);
    auto Y = Op->getOperand(0);
    auto Z = Op->getOperand(1);
    return CDAG.getNode(VVPOpcode, LegalVecVT, {X, Y, Z, Mask, AVL});
  }
  case VEISD::VVP_SELECT: {
    auto Mask = Op->getOperand(0);
    auto OnTrue = Op->getOperand(1);
    auto OnFalse = Op->getOperand(2);
    return CDAG.getNode(VVPOpcode, LegalVecVT, {OnTrue, OnFalse, Mask, AVL});
  }
  case VEISD::VVP_SETCC: {
    EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
    auto LHS = Op->getOperand(0);
    auto RHS = Op->getOperand(1);
    auto Pred = Op->getOperand(2);
    return CDAG.getNode(VVPOpcode, LegalResVT, {LHS, RHS, Pred, Mask, AVL});
  }
  }
}

SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
                                              VECustomDAG &CDAG) const {
  auto VVPOpc = *getVVPOpcode(Op->getOpcode());
  const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD);

  // Shares.
  SDValue BasePtr = getMemoryPtr(Op);
  SDValue Mask = getNodeMask(Op);
  SDValue Chain = getNodeChain(Op);
  SDValue AVL = getNodeAVL(Op);
  // Store specific.
  SDValue Data = getStoredValue(Op);
  // Load specific.
  SDValue PassThru = getNodePassthru(Op);

  SDValue StrideV = getLoadStoreStride(Op, CDAG);

  auto DataVT = *getIdiomaticVectorType(Op.getNode());
  auto Packing = getTypePacking(DataVT);

  // TODO: Infer lower AVL from mask.
  if (!AVL)
    AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);

  // Default to the all-true mask.
  if (!Mask)
    Mask = CDAG.getConstantMask(Packing, true);

  if (IsLoad) {
    MVT LegalDataVT = getLegalVectorType(
        Packing, DataVT.getVectorElementType().getSimpleVT());

    auto NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other},
                                 {Chain, BasePtr, StrideV, Mask, AVL});

    if (!PassThru || PassThru->isUndef())
      return NewLoadV;

    // Convert passthru to an explicit select node.
    SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, DataVT,
                                 {NewLoadV, PassThru, Mask, AVL});
    SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);

    // Merge them back into one node.
    return CDAG.getMergeValues({DataV, NewLoadChainV});
  }

  // VVP_STORE
  assert(VVPOpc == VEISD::VVP_STORE);
  if (getTypeAction(*CDAG.getDAG()->getContext(), Data.getValueType()) !=
      TargetLowering::TypeLegal)
    // Doesn't lower store instruction if an operand is not lowered yet.
    // If it isn't, return SDValue().  In this way, LLVM will try to lower
    // store instruction again after lowering all operands.
    return SDValue();
  return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(),
                      {Chain, Data, BasePtr, StrideV, Mask, AVL});
}

SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
                                               VECustomDAG &CDAG) const {
  auto VVPOC = *getVVPOpcode(Op.getOpcode());
  assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));

  MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
  assert(getTypePacking(DataVT) == Packing::Dense &&
         "Can only split packed load/store");
  MVT SplitDataVT = splitVectorType(DataVT);

  assert(!getNodePassthru(Op) &&
         "Should have been folded in lowering to VVP layer");

  // Analyze the operation
  SDValue PackedMask = getNodeMask(Op);
  SDValue PackedAVL = getAnnotatedNodeAVL(Op).first;
  SDValue PackPtr = getMemoryPtr(Op);
  SDValue PackData = getStoredValue(Op);
  SDValue PackStride = getLoadStoreStride(Op, CDAG);

  unsigned ChainResIdx = PackData ? 0 : 1;

  SDValue PartOps[2];

  SDValue UpperPartAVL; // we will use this for packing things back together
  for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
    // VP ops already have an explicit mask and AVL. When expanding from non-VP
    // attach those additional inputs here.
    auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);

    // Keep track of the (higher) lvl.
    if (Part == PackElem::Hi)
      UpperPartAVL = SplitTM.AVL;

    // Attach non-predicating value operands
    SmallVector<SDValue, 4> OpVec;

    // Chain
    OpVec.push_back(getNodeChain(Op));

    // Data
    if (PackData) {
      SDValue PartData =
          CDAG.getUnpack(SplitDataVT, PackData, Part, SplitTM.AVL);
      OpVec.push_back(PartData);
    }

    // Ptr & Stride
    // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes)
    // Stride info
    // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode);
    OpVec.push_back(CDAG.getSplitPtrOffset(PackPtr, PackStride, Part));
    OpVec.push_back(CDAG.getSplitPtrStride(PackStride));

    // Add predicating args and generate part node
    OpVec.push_back(SplitTM.Mask);
    OpVec.push_back(SplitTM.AVL);

    if (PackData) {
      // Store
      PartOps[(int)Part] = CDAG.getNode(VVPOC, MVT::Other, OpVec);
    } else {
      // Load
      PartOps[(int)Part] =
          CDAG.getNode(VVPOC, {SplitDataVT, MVT::Other}, OpVec);
    }
  }

  // Merge the chains
  SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx);
  SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx);
  SDValue FusedChains =
      CDAG.getNode(ISD::TokenFactor, MVT::Other, {LowChain, HiChain});

  // Chain only [store]
  if (PackData)
    return FusedChains;

  // Re-pack into full packed vector result
  MVT PackedVT =
      getLegalVectorType(Packing::Dense, DataVT.getVectorElementType());
  SDValue PackedVals = CDAG.getPack(PackedVT, PartOps[(int)PackElem::Lo],
                                    PartOps[(int)PackElem::Hi], UpperPartAVL);

  return CDAG.getMergeValues({PackedVals, FusedChains});
}

SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
                                                  VECustomDAG &CDAG) const {
  EVT DataVT = *getIdiomaticVectorType(Op.getNode());
  auto Packing = getTypePacking(DataVT);
  MVT LegalDataVT =
      getLegalVectorType(Packing, DataVT.getVectorElementType().getSimpleVT());

  SDValue AVL = getAnnotatedNodeAVL(Op).first;
  SDValue Index = getGatherScatterIndex(Op);
  SDValue BasePtr = getMemoryPtr(Op);
  SDValue Mask = getNodeMask(Op);
  SDValue Chain = getNodeChain(Op);
  SDValue Scale = getGatherScatterScale(Op);
  SDValue PassThru = getNodePassthru(Op);
  SDValue StoredValue = getStoredValue(Op);
  if (PassThru && PassThru->isUndef())
    PassThru = SDValue();

  bool IsScatter = (bool)StoredValue;

  // TODO: Infer lower AVL from mask.
  if (!AVL)
    AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);

  // Default to the all-true mask.
  if (!Mask)
    Mask = CDAG.getConstantMask(Packing, true);

  SDValue AddressVec =
      CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL);
  if (IsScatter)
    return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other,
                        {Chain, StoredValue, AddressVec, Mask, AVL});

  // Gather.
  SDValue NewLoadV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other},
                                  {Chain, AddressVec, Mask, AVL});

  if (!PassThru)
    return NewLoadV;

  // TODO: Use vvp_select
  SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,
                               {NewLoadV, PassThru, Mask, AVL});
  SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
  return CDAG.getMergeValues({DataV, NewLoadChainV});
}

SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
                                                      VECustomDAG &CDAG) const {
  LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
  MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();

  // TODO: Recognize packable load,store.
  if (isPackedVectorType(DataVT))
    return splitPackedLoadStore(Op, CDAG);

  return legalizePackedAVL(Op, CDAG);
}

SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
                                                   SelectionDAG &DAG) const {
  LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
  VECustomDAG CDAG(DAG, Op);

  // Dispatch to specialized legalization functions.
  switch (Op->getOpcode()) {
  case VEISD::VVP_LOAD:
  case VEISD::VVP_STORE:
    return legalizeInternalLoadStoreOp(Op, CDAG);
  }

  EVT IdiomVT = Op.getValueType();
  if (isPackedVectorType(IdiomVT) &&
      !supportsPackedMode(Op.getOpcode(), IdiomVT))
    return splitVectorOp(Op, CDAG);

  // TODO: Implement odd/even splitting.
  return legalizePackedAVL(Op, CDAG);
}

SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
  MVT ResVT = splitVectorType(Op.getValue(0).getSimpleValueType());

  auto AVLPos = getAVLPos(Op->getOpcode());
  auto MaskPos = getMaskPos(Op->getOpcode());

  SDValue PackedMask = getNodeMask(Op);
  auto AVLPair = getAnnotatedNodeAVL(Op);
  SDValue PackedAVL = AVLPair.first;
  assert(!AVLPair.second && "Expecting non pack-legalized oepration");

  // request the parts
  SDValue PartOps[2];

  SDValue UpperPartAVL; // we will use this for packing things back together
  for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
    // VP ops already have an explicit mask and AVL. When expanding from non-VP
    // attach those additional inputs here.
    auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);

    if (Part == PackElem::Hi)
      UpperPartAVL = SplitTM.AVL;

    // Attach non-predicating value operands
    SmallVector<SDValue, 4> OpVec;
    for (unsigned i = 0; i < Op.getNumOperands(); ++i) {
      if (AVLPos && ((int)i) == *AVLPos)
        continue;
      if (MaskPos && ((int)i) == *MaskPos)
        continue;

      // Value operand
      auto PackedOperand = Op.getOperand(i);
      auto UnpackedOpVT = splitVectorType(PackedOperand.getSimpleValueType());
      SDValue PartV =
          CDAG.getUnpack(UnpackedOpVT, PackedOperand, Part, SplitTM.AVL);
      OpVec.push_back(PartV);
    }

    // Add predicating args and generate part node.
    OpVec.push_back(SplitTM.Mask);
    OpVec.push_back(SplitTM.AVL);
    // Emit legal VVP nodes.
    PartOps[(int)Part] =
        CDAG.getNode(Op.getOpcode(), ResVT, OpVec, Op->getFlags());
  }

  // Re-package vectors.
  return CDAG.getPack(Op.getValueType(), PartOps[(int)PackElem::Lo],
                      PartOps[(int)PackElem::Hi], UpperPartAVL);
}

SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
                                            VECustomDAG &CDAG) const {
  LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
  // Only required for VEC and VVP ops.
  if (!isVVPOrVEC(Op->getOpcode()))
    return Op;

  // Operation already has a legal AVL.
  auto AVL = getNodeAVL(Op);
  if (isLegalAVL(AVL))
    return Op;

  // Half and round up EVL for 32bit element types.
  SDValue LegalAVL = AVL;
  MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
  if (isPackedVectorType(IdiomVT)) {
    assert(maySafelyIgnoreMask(Op) &&
           "TODO Shift predication from EVL into Mask");

    if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) {
      LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
    } else {
      auto ConstOne = CDAG.getConstant(1, MVT::i32);
      auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne});
      LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne});
    }
  }

  SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL);

  // Copy the operand list.
  int NumOp = Op->getNumOperands();
  auto AVLPos = getAVLPos(Op->getOpcode());
  std::vector<SDValue> FixedOperands;
  for (int i = 0; i < NumOp; ++i) {
    if (AVLPos && (i == *AVLPos)) {
      FixedOperands.push_back(AnnotatedLegalAVL);
      continue;
    }
    FixedOperands.push_back(Op->getOperand(i));
  }

  // Clone the operation with fixed operands.
  auto Flags = Op->getFlags();
  SDValue NewN =
      CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands, Flags);
  return NewN;
}