aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp')
-rw-r--r--llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp149
1 files changed, 8 insertions, 141 deletions
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 64b9dc3..163bf9b 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {
- // Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
// Combine wide-vector muls, with extend inputs, to extmul_half.
@@ -317,6 +316,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
}
+
+ // Partial MLA reductions.
+ for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
+ setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal);
+ setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v8i16, Legal);
+ }
}
// As a special case, these operators use the type to mean the type to
@@ -416,41 +421,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
return TargetLowering::getPointerMemTy(DL, AS);
}
-bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
- const IntrinsicInst *I) const {
- if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
- return true;
-
- EVT VT = EVT::getEVT(I->getType());
- if (VT.getSizeInBits() > 128)
- return true;
-
- auto Op1 = I->getOperand(1);
-
- if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
- unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
- if (Opcode == ISD::MUL) {
- if (isa<Instruction>(InputInst->getOperand(0)) &&
- isa<Instruction>(InputInst->getOperand(1))) {
- // dot only supports signed inputs but also support lowering unsigned.
- if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
- cast<Instruction>(InputInst->getOperand(1))->getOpcode())
- return true;
-
- EVT Op1VT = EVT::getEVT(Op1->getType());
- if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
- ((VT.getVectorElementCount() * 2 ==
- Op1VT.getVectorElementCount()) ||
- (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
- return false;
- }
- } else if (ISD::isExtOpcode(Opcode)) {
- return false;
- }
- }
- return true;
-}
-
TargetLowering::AtomicExpansionKind
WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
// We have wasm instructions for these
@@ -2113,106 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
MachinePointerInfo(SV));
}
-// Try to lower partial.reduce.add to a dot or fallback to a sequence with
-// extmul and adds.
-SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
- if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add)
- return SDValue();
-
- assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
- SDLoc DL(N);
-
- SDValue Input = N->getOperand(2);
- if (Input->getOpcode() == ISD::MUL) {
- SDValue ExtendLHS = Input->getOperand(0);
- SDValue ExtendRHS = Input->getOperand(1);
- assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
- ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
- "expected widening mul or add");
- assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
- "expected binop to use the same extend for both operands");
-
- SDValue ExtendInLHS = ExtendLHS->getOperand(0);
- SDValue ExtendInRHS = ExtendRHS->getOperand(0);
- bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
- unsigned LowOpc =
- IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
- unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
- : WebAssemblyISD::EXTEND_HIGH_U;
- SDValue LowLHS;
- SDValue LowRHS;
- SDValue HighLHS;
- SDValue HighRHS;
-
- auto AssignInputs = [&](MVT VT) {
- LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS);
- LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS);
- HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
- HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
- };
-
- if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
- if (IsSigned) {
- // i32x4.dot_i16x8_s
- SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
- ExtendInLHS, ExtendInRHS);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
- }
-
- // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
- MVT VT = MVT::v4i32;
- AssignInputs(VT);
- SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
- SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
- SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
- return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
- } else {
- assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
- "expected v16i8 input types");
- AssignInputs(MVT::v8i16);
- // Lower to a wider tree, using twice the operations compared to above.
- if (IsSigned) {
- // Use two dots
- SDValue DotLHS =
- DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
- SDValue DotRHS =
- DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
- }
-
- SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
- SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
-
- SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
- MVT::v4i32, MulLow);
- SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
- MVT::v4i32, MulHigh);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
- }
- } else {
- // Accumulate the input using extadd_pairwise.
- assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
- bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
- unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
- : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
- SDValue ExtendIn = Input->getOperand(0);
- if (ExtendIn->getValueType(0) == MVT::v8i16) {
- SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
- }
-
- assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
- "expected v16i8 input types");
- SDValue Add =
- DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
- DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
- }
-}
-
SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
@@ -3683,11 +3553,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performVectorTruncZeroCombine(N, DCI);
case ISD::TRUNCATE:
return performTruncateCombine(N, DCI);
- case ISD::INTRINSIC_WO_CHAIN: {
- if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG))
- return AnyAllCombine;
- return performLowerPartialReduction(N, DCI.DAG);
- }
+ case ISD::INTRINSIC_WO_CHAIN:
+ return performAnyAllCombine(N, DCI.DAG);
case ISD::MUL:
return performMulCombine(N, DCI);
}