Diffstat (limited to 'llvm/lib/Target/WebAssembly')
-rw-r--r--  llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp | 149
-rw-r--r--  llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h   |   5
-rw-r--r--  llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td     |  45
3 files changed, 54 insertions, 145 deletions
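
This change drops the hand-written expansion of llvm.vector.partial.reduce.add in favour of marking the PARTIAL_REDUCE_SMLA/UMLA nodes Legal and selecting them with TableGen patterns. For context, the kind of source loop these reductions come from looks like the following C++ sketch (illustrative only; whether a vectorizer actually forms the partial-reduction intrinsic depends on cost-model hooks outside this patch):

#include <cstddef>
#include <cstdint>

// Illustrative only: an int8 dot product accumulated in 32 bits. Vectorized
// with <16 x i8> inputs, the reduction can be expressed with
// llvm.vector.partial.reduce.add, which this patch now lowers through
// ISD::PARTIAL_REDUCE_SMLA to i32x4.dot_i16x8_s-based sequences.
int32_t dot_i8(const int8_t *a, const int8_t *b, size_t n) {
  int32_t sum = 0;
  for (size_t i = 0; i != n; ++i)
    sum += int32_t(a[i]) * int32_t(b[i]);
  return sum;
}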
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 64b9dc3..163bf9b 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {
- // Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
// Combine wide-vector muls, with extend inputs, to extmul_half.
@@ -317,6 +316,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
}
+
+ // Partial MLA reductions.
+ for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
+ setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal);
+ setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v8i16, Legal);
+ }
}
// As a special case, these operators use the type to mean the type to
@@ -416,41 +421,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
return TargetLowering::getPointerMemTy(DL, AS);
}
-bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
- const IntrinsicInst *I) const {
- if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
- return true;
-
- EVT VT = EVT::getEVT(I->getType());
- if (VT.getSizeInBits() > 128)
- return true;
-
- auto Op1 = I->getOperand(1);
-
- if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
- unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
- if (Opcode == ISD::MUL) {
- if (isa<Instruction>(InputInst->getOperand(0)) &&
- isa<Instruction>(InputInst->getOperand(1))) {
- // dot only supports signed inputs but also support lowering unsigned.
- if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
- cast<Instruction>(InputInst->getOperand(1))->getOpcode())
- return true;
-
- EVT Op1VT = EVT::getEVT(Op1->getType());
- if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
- ((VT.getVectorElementCount() * 2 ==
- Op1VT.getVectorElementCount()) ||
- (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
- return false;
- }
- } else if (ISD::isExtOpcode(Opcode)) {
- return false;
- }
- }
- return true;
-}
-
TargetLowering::AtomicExpansionKind
WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
// We have wasm instructions for these
@@ -2113,106 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
MachinePointerInfo(SV));
}
-// Try to lower partial.reduce.add to a dot or fallback to a sequence with
-// extmul and adds.
-SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
- if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add)
- return SDValue();
-
- assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
- SDLoc DL(N);
-
- SDValue Input = N->getOperand(2);
- if (Input->getOpcode() == ISD::MUL) {
- SDValue ExtendLHS = Input->getOperand(0);
- SDValue ExtendRHS = Input->getOperand(1);
- assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
- ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
- "expected widening mul or add");
- assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
- "expected binop to use the same extend for both operands");
-
- SDValue ExtendInLHS = ExtendLHS->getOperand(0);
- SDValue ExtendInRHS = ExtendRHS->getOperand(0);
- bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
- unsigned LowOpc =
- IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
- unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
- : WebAssemblyISD::EXTEND_HIGH_U;
- SDValue LowLHS;
- SDValue LowRHS;
- SDValue HighLHS;
- SDValue HighRHS;
-
- auto AssignInputs = [&](MVT VT) {
- LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS);
- LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS);
- HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
- HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
- };
-
- if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
- if (IsSigned) {
- // i32x4.dot_i16x8_s
- SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
- ExtendInLHS, ExtendInRHS);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
- }
-
- // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
- MVT VT = MVT::v4i32;
- AssignInputs(VT);
- SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
- SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
- SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
- return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
- } else {
- assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
- "expected v16i8 input types");
- AssignInputs(MVT::v8i16);
- // Lower to a wider tree, using twice the operations compared to above.
- if (IsSigned) {
- // Use two dots
- SDValue DotLHS =
- DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
- SDValue DotRHS =
- DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
- }
-
- SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
- SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
-
- SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
- MVT::v4i32, MulLow);
- SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
- MVT::v4i32, MulHigh);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
- }
- } else {
- // Accumulate the input using extadd_pairwise.
- assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
- bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
- unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
- : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
- SDValue ExtendIn = Input->getOperand(0);
- if (ExtendIn->getValueType(0) == MVT::v8i16) {
- SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
- }
-
- assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
- "expected v16i8 input types");
- SDValue Add =
- DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
- DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
- }
-}
-
SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
@@ -3683,11 +3553,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performVectorTruncZeroCombine(N, DCI);
case ISD::TRUNCATE:
return performTruncateCombine(N, DCI);
- case ISD::INTRINSIC_WO_CHAIN: {
- if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG))
- return AnyAllCombine;
- return performLowerPartialReduction(N, DCI.DAG);
- }
+ case ISD::INTRINSIC_WO_CHAIN:
+ return performAnyAllCombine(N, DCI.DAG);
case ISD::MUL:
return performMulCombine(N, DCI);
}
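
The deleted performLowerPartialReduction combine built these DAGs by hand; the TableGen patterns added to WebAssemblyInstrSIMD.td below select the same instruction sequences. For reference, a scalar sketch (illustrative only, not part of the patch) of what the signed v8i16 -> v4i32 pattern, (ADD_I32x4 (DOT $lhs, $rhs), $acc), computes per output lane:

#include <cstdint>

// Scalar model of partial_reduce_smla with a v4i32 accumulator and v8i16
// inputs, i.e. i32x4.dot_i16x8_s followed by i32x4.add.
void partial_reduce_smla_v8i16(int32_t acc[4], const int16_t lhs[8],
                               const int16_t rhs[8]) {
  for (int lane = 0; lane < 4; ++lane) {
    // i32x4.dot_i16x8_s: sign-extend each adjacent i16 pair, multiply, and
    // add the two products into one i32 lane.
    int32_t dot = int32_t(lhs[2 * lane]) * rhs[2 * lane] +
                  int32_t(lhs[2 * lane + 1]) * rhs[2 * lane + 1];
    // i32x4.add with the accumulator (wraps on wasm; modelled via uint32_t).
    acc[lane] = int32_t(uint32_t(acc[lane]) + uint32_t(dot));
  }
}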
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
index 72401a7..b33a853 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -45,8 +45,6 @@ private:
/// right decision when generating code for different targets.
const WebAssemblySubtarget *Subtarget;
- bool
- shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
bool shouldScalarizeBinop(SDValue VecOp) const override;
FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
@@ -89,8 +87,7 @@ private:
bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF,
bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
- LLVMContext &Context,
- const Type *RetTy) const override;
+ LLVMContext &Context, const Type *RetTy) const override;
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index d8948ad..1306026 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1505,6 +1505,51 @@ defm Q15MULR_SAT_S :
SIMDBinary<I16x8, int_wasm_q15mulr_sat_signed, "q15mulr_sat_s", 0x82>;
//===----------------------------------------------------------------------===//
+// Partial reductions, using: dot, extmul and extadd_pairwise
+//===----------------------------------------------------------------------===//
+// MLA: v8i16 -> v4i32
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs),
+ (v8i16 V128:$rhs))),
+ (ADD_I32x4 (DOT $lhs, $rhs), $acc)>;
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$lhs),
+ (v8i16 V128:$rhs))),
+ (ADD_I32x4 (ADD_I32x4 (EXTMUL_LOW_U_I32x4 $lhs, $rhs),
+ (EXTMUL_HIGH_U_I32x4 $lhs, $rhs)),
+ $acc)>;
+// MLA: v16i8 -> v4i32
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$lhs),
+ (v16i8 V128:$rhs))),
+ (ADD_I32x4 (ADD_I32x4 (DOT (extend_low_s_I16x8 $lhs),
+ (extend_low_s_I16x8 $rhs)),
+ (DOT (extend_high_s_I16x8 $lhs),
+ (extend_high_s_I16x8 $rhs))),
+ $acc)>;
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$lhs),
+ (v16i8 V128:$rhs))),
+ (ADD_I32x4 (ADD_I32x4 (extadd_pairwise_u_I32x4 (EXTMUL_LOW_U_I16x8 $lhs, $rhs)),
+ (extadd_pairwise_u_I32x4 (EXTMUL_HIGH_U_I16x8 $lhs, $rhs))),
+ $acc)>;
+
+// Accumulate: v8i16 -> v4i32
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in),
+ (I16x8.splat (i32 1)))),
+ (ADD_I32x4 (extadd_pairwise_s_I32x4 $in), $acc)>;
+
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in),
+ (I16x8.splat (i32 1)))),
+ (ADD_I32x4 (extadd_pairwise_u_I32x4 $in), $acc)>;
+
+// Accumulate: v16i8 -> v4i32
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$in),
+ (I8x16.splat (i32 1)))),
+ (ADD_I32x4 (extadd_pairwise_s_I32x4 (extadd_pairwise_s_I16x8 $in)),
+ $acc)>;
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$in),
+ (I8x16.splat (i32 1)))),
+ (ADD_I32x4 (extadd_pairwise_u_I32x4 (extadd_pairwise_u_I16x8 $in)),
+ $acc)>;
+
+//===----------------------------------------------------------------------===//
// Relaxed swizzle
//===----------------------------------------------------------------------===//
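
Back in the partial-reductions block above, the splat-of-1 "Accumulate" patterns fold a widening horizontal add into the accumulator instead of a multiply-add. A scalar sketch (illustrative only) of the unsigned v16i8 -> v4i32 case, which the (extadd_pairwise_u_I32x4 (extadd_pairwise_u_I16x8 $in)) chain implements:

#include <cstdint>

// Scalar model of the unsigned v16i8 -> v4i32 accumulate pattern:
// i16x8.extadd_pairwise_i8x16_u widens and sums adjacent byte pairs, then
// i32x4.extadd_pairwise_i16x8_u does the same on the i16 lanes, so each i32
// lane receives the zero-extended sum of four consecutive input bytes.
void partial_reduce_umla_splat1_v16i8(uint32_t acc[4], const uint8_t in[16]) {
  for (int lane = 0; lane < 4; ++lane) {
    uint32_t sum = 0;
    for (int k = 0; k < 4; ++k)
      sum += in[4 * lane + k];
    acc[lane] += sum; // i32x4.add with the accumulator (wrapping).
  }
}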