diff options
Diffstat (limited to 'llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp | 149 |
1 files changed, 8 insertions, 141 deletions
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index 64b9dc3..163bf9b 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( // SIMD-specific configuration if (Subtarget->hasSIMD128()) { - // Combine partial.reduce.add before legalization gets confused. setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN); // Combine wide-vector muls, with extend inputs, to extmul_half. @@ -317,6 +316,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom); setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom); } + + // Partial MLA reductions. + for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) { + setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal); + setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v8i16, Legal); + } } // As a special case, these operators use the type to mean the type to @@ -416,41 +421,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL, return TargetLowering::getPointerMemTy(DL, AS); } -bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic( - const IntrinsicInst *I) const { - if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add) - return true; - - EVT VT = EVT::getEVT(I->getType()); - if (VT.getSizeInBits() > 128) - return true; - - auto Op1 = I->getOperand(1); - - if (auto *InputInst = dyn_cast<Instruction>(Op1)) { - unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode()); - if (Opcode == ISD::MUL) { - if (isa<Instruction>(InputInst->getOperand(0)) && - isa<Instruction>(InputInst->getOperand(1))) { - // dot only supports signed inputs but also support lowering unsigned. - if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() != - cast<Instruction>(InputInst->getOperand(1))->getOpcode()) - return true; - - EVT Op1VT = EVT::getEVT(Op1->getType()); - if (Op1VT.getVectorElementType() == VT.getVectorElementType() && - ((VT.getVectorElementCount() * 2 == - Op1VT.getVectorElementCount()) || - (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount()))) - return false; - } - } else if (ISD::isExtOpcode(Opcode)) { - return false; - } - } - return true; -} - TargetLowering::AtomicExpansionKind WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { // We have wasm instructions for these @@ -2113,106 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op, MachinePointerInfo(SV)); } -// Try to lower partial.reduce.add to a dot or fallback to a sequence with -// extmul and adds. -SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) { - assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN); - if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add) - return SDValue(); - - assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32"); - SDLoc DL(N); - - SDValue Input = N->getOperand(2); - if (Input->getOpcode() == ISD::MUL) { - SDValue ExtendLHS = Input->getOperand(0); - SDValue ExtendRHS = Input->getOperand(1); - assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) && - ISD::isExtOpcode(ExtendRHS.getOpcode())) && - "expected widening mul or add"); - assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() && - "expected binop to use the same extend for both operands"); - - SDValue ExtendInLHS = ExtendLHS->getOperand(0); - SDValue ExtendInRHS = ExtendRHS->getOperand(0); - bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND; - unsigned LowOpc = - IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U; - unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S - : WebAssemblyISD::EXTEND_HIGH_U; - SDValue LowLHS; - SDValue LowRHS; - SDValue HighLHS; - SDValue HighRHS; - - auto AssignInputs = [&](MVT VT) { - LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS); - LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS); - HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS); - HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS); - }; - - if (ExtendInLHS->getValueType(0) == MVT::v8i16) { - if (IsSigned) { - // i32x4.dot_i16x8_s - SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, - ExtendInLHS, ExtendInRHS); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot); - } - - // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs))) - MVT VT = MVT::v4i32; - AssignInputs(VT); - SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS); - SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS); - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh); - return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add); - } else { - assert(ExtendInLHS->getValueType(0) == MVT::v16i8 && - "expected v16i8 input types"); - AssignInputs(MVT::v8i16); - // Lower to a wider tree, using twice the operations compared to above. - if (IsSigned) { - // Use two dots - SDValue DotLHS = - DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS); - SDValue DotRHS = - DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS); - SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); - } - - SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS); - SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS); - - SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, - MVT::v4i32, MulLow); - SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, - MVT::v4i32, MulHigh); - SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); - } - } else { - // Accumulate the input using extadd_pairwise. - assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend"); - bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND; - unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S - : WebAssemblyISD::EXT_ADD_PAIRWISE_U; - SDValue ExtendIn = Input->getOperand(0); - if (ExtendIn->getValueType(0) == MVT::v8i16) { - SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); - } - - assert(ExtendIn->getValueType(0) == MVT::v16i8 && - "expected v16i8 input types"); - SDValue Add = - DAG.getNode(PairwiseOpc, DL, MVT::v4i32, - DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn)); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); - } -} - SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op, SelectionDAG &DAG) const { MachineFunction &MF = DAG.getMachineFunction(); @@ -3683,11 +3553,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, return performVectorTruncZeroCombine(N, DCI); case ISD::TRUNCATE: return performTruncateCombine(N, DCI); - case ISD::INTRINSIC_WO_CHAIN: { - if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG)) - return AnyAllCombine; - return performLowerPartialReduction(N, DCI.DAG); - } + case ISD::INTRINSIC_WO_CHAIN: + return performAnyAllCombine(N, DCI.DAG); case ISD::MUL: return performMulCombine(N, DCI); } |