aboutsummaryrefslogtreecommitdiff
path: root/llvm/test/CodeGen/AMDGPU/mad-mix-hi-bf16.ll
blob: 5b2de59f7d271a95e25e0e41e2d5164a7d5f3845 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mtriple=amdgcn -mcpu=gfx1250 < %s | FileCheck -check-prefixes=GFX1250 %s

; bf16 fmuladd widened to f32, rounded back to bf16 and inserted into the high
; half of a vector whose low half is undef. Expected to select a single
; v_fma_mixhi_bf16 writing directly into the return register.
; NOTE(review): upstream LLVM discourages new uses of `undef` in tests; using
; `poison` here may be possible but would require regenerating the checks.
define <2 x bfloat> @v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_undeflo(bfloat %src0, bfloat %src1, bfloat %src2) #0 {
; GFX1250-LABEL: v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_undeflo:
; GFX1250:       ; %bb.0:
; GFX1250-NEXT:    s_wait_loadcnt_dscnt 0x0
; GFX1250-NEXT:    s_wait_kmcnt 0x0
; GFX1250-NEXT:    v_fma_mixhi_bf16 v0, v0, v1, v2 op_sel_hi:[1,1,1]
; GFX1250-NEXT:    s_set_pc_i64 s[30:31]
  %a.f32 = fpext bfloat %src0 to float
  %b.f32 = fpext bfloat %src1 to float
  %c.f32 = fpext bfloat %src2 to float
  %fma.f32 = tail call float @llvm.fmuladd.f32(float %a.f32, float %b.f32, float %c.f32)
  %fma.bf16 = fptrunc float %fma.f32 to bfloat
  %packed = insertelement <2 x bfloat> undef, bfloat %fma.bf16, i32 1
  ret <2 x bfloat> %packed
}

; Same mix as the undef-low case, but the low half is the constant bf16 1.0.
; The constant (0x3f80) is materialized into a scratch VGPR first, the mixhi
; result is merged into that register's high half, then copied to v0.
define <2 x bfloat> @v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_constlo(bfloat %src0, bfloat %src1, bfloat %src2) #0 {
; GFX1250-LABEL: v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_constlo:
; GFX1250:       ; %bb.0:
; GFX1250-NEXT:    s_wait_loadcnt_dscnt 0x0
; GFX1250-NEXT:    s_wait_kmcnt 0x0
; GFX1250-NEXT:    v_mov_b32_e32 v3, 0x3f80
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX1250-NEXT:    v_fma_mixhi_bf16 v3, v0, v1, v2 op_sel_hi:[1,1,1]
; GFX1250-NEXT:    v_mov_b32_e32 v0, v3
; GFX1250-NEXT:    s_set_pc_i64 s[30:31]
  %a.f32 = fpext bfloat %src0 to float
  %b.f32 = fpext bfloat %src1 to float
  %c.f32 = fpext bfloat %src2 to float
  %fma.f32 = tail call float @llvm.fmuladd.f32(float %a.f32, float %b.f32, float %c.f32)
  %fma.bf16 = fptrunc float %fma.f32 to bfloat
  %packed = insertelement <2 x bfloat> <bfloat 1.0, bfloat undef>, bfloat %fma.bf16, i32 1
  ret <2 x bfloat> %packed
}

; The low half comes from a register argument (%lo, passed in v3). The mixhi
; writes the high half of v3 in place, keeping %lo in the low half, and the
; combined value is then copied to the return register v0.
define <2 x bfloat> @v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_reglo(bfloat %src0, bfloat %src1, bfloat %src2, bfloat %lo) #0 {
; GFX1250-LABEL: v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_reglo:
; GFX1250:       ; %bb.0:
; GFX1250-NEXT:    s_wait_loadcnt_dscnt 0x0
; GFX1250-NEXT:    s_wait_kmcnt 0x0
; GFX1250-NEXT:    v_fma_mixhi_bf16 v3, v0, v1, v2 op_sel_hi:[1,1,1]
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1)
; GFX1250-NEXT:    v_mov_b32_e32 v0, v3
; GFX1250-NEXT:    s_set_pc_i64 s[30:31]
  %a.f32 = fpext bfloat %src0 to float
  %b.f32 = fpext bfloat %src1 to float
  %c.f32 = fpext bfloat %src2 to float
  %fma.f32 = tail call float @llvm.fmuladd.f32(float %a.f32, float %b.f32, float %c.f32)
  %fma.bf16 = fptrunc float %fma.f32 to bfloat
  %lo.only = insertelement <2 x bfloat> undef, bfloat %lo, i32 0
  %packed = insertelement <2 x bfloat> %lo.only, bfloat %fma.bf16, i32 1
  ret <2 x bfloat> %packed
}

; Packs the bf16 mix result into bits [31:16] of an i32 using integer ops
; (bitcast, zext, shl by 16). This is currently selected as
; v_fma_mixlo_bf16 followed by a 16-bit left shift, not as a single mixhi.
define i32 @v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_intpack(bfloat %src0, bfloat %src1, bfloat %src2) #0 {
; GFX1250-LABEL: v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_intpack:
; GFX1250:       ; %bb.0:
; GFX1250-NEXT:    s_wait_loadcnt_dscnt 0x0
; GFX1250-NEXT:    s_wait_kmcnt 0x0
; GFX1250-NEXT:    v_fma_mixlo_bf16 v0, v0, v1, v2 op_sel_hi:[1,1,1]
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1)
; GFX1250-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
; GFX1250-NEXT:    s_set_pc_i64 s[30:31]
  %src0.ext = fpext bfloat %src0 to float
  %src1.ext = fpext bfloat %src1 to float
  %src2.ext = fpext bfloat %src2 to float
  %result = tail call float @llvm.fmuladd.f32(float %src0.ext, float %src1.ext, float %src2.ext)
  %cvt.result = fptrunc float %result to bfloat
  %bc = bitcast bfloat %cvt.result to i16
  %ext = zext i16 %bc to i32
  ; Left shift moves the bf16 bits into the high half (was misnamed %shr).
  %shl = shl i32 %ext, 16
  ret i32 %shl
}

; Variant of the integer-pack test using sext instead of zext. The shl by 16
; discards the 16 extension bits, so sext and zext are equivalent here. The
; expected code is identical: v_fma_mixlo_bf16 followed by a 16-bit shift.
define i32 @v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_intpack_sext(bfloat %src0, bfloat %src1, bfloat %src2) #0 {
; GFX1250-LABEL: v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_intpack_sext:
; GFX1250:       ; %bb.0:
; GFX1250-NEXT:    s_wait_loadcnt_dscnt 0x0
; GFX1250-NEXT:    s_wait_kmcnt 0x0
; GFX1250-NEXT:    v_fma_mixlo_bf16 v0, v0, v1, v2 op_sel_hi:[1,1,1]
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1)
; GFX1250-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
; GFX1250-NEXT:    s_set_pc_i64 s[30:31]
  %src0.ext = fpext bfloat %src0 to float
  %src1.ext = fpext bfloat %src1 to float
  %src2.ext = fpext bfloat %src2 to float
  %result = tail call float @llvm.fmuladd.f32(float %src0.ext, float %src1.ext, float %src2.ext)
  %cvt.result = fptrunc float %result to bfloat
  %bc = bitcast bfloat %cvt.result to i16
  %ext = sext i16 %bc to i32
  ; Left shift moves the bf16 bits into the high half (was misnamed %shr).
  %shl = shl i32 %ext, 16
  ret i32 %shl
}

; Clamps the f32 fmuladd result to [0.0, 1.0] (maxnum then minnum) before
; rounding to bf16. The clamp folds into the f32-result mix as the `clamp`
; modifier. The bf16 conversion and the move into the high half remain
; separate instructions, so no mixhi is formed.
define <2 x bfloat> @v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_undeflo_clamp_precvt(bfloat %src0, bfloat %src1, bfloat %src2) #0 {
; GFX1250-LABEL: v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_undeflo_clamp_precvt:
; GFX1250:       ; %bb.0:
; GFX1250-NEXT:    s_wait_loadcnt_dscnt 0x0
; GFX1250-NEXT:    s_wait_kmcnt 0x0
; GFX1250-NEXT:    v_fma_mix_f32_bf16 v0, v0, v1, v2 op_sel_hi:[1,1,1] clamp
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX1250-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
; GFX1250-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
; GFX1250-NEXT:    s_set_pc_i64 s[30:31]
  %a.f32 = fpext bfloat %src0 to float
  %b.f32 = fpext bfloat %src1 to float
  %c.f32 = fpext bfloat %src2 to float
  %fma.f32 = tail call float @llvm.fmuladd.f32(float %a.f32, float %b.f32, float %c.f32)
  %above.zero = call float @llvm.maxnum.f32(float %fma.f32, float 0.0)
  %in.unit = call float @llvm.minnum.f32(float %above.zero, float 1.0)
  %in.unit.bf16 = fptrunc float %in.unit to bfloat
  %packed = insertelement <2 x bfloat> undef, bfloat %in.unit.bf16, i32 1
  ret <2 x bfloat> %packed
}

; Clamps to [0.0, 1.0] in bf16, after the fptrunc. The clamp is not folded into
; the mix. Each bf16 maxnum/minnum is widened to f32 (shift left by 16),
; computed with v_max/v_min_num_f32, and rounded back with v_cvt_pk_bf16_f32.
define <2 x bfloat> @v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_undeflo_clamp_postcvt(bfloat %src0, bfloat %src1, bfloat %src2) #0 {
; GFX1250-LABEL: v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_undeflo_clamp_postcvt:
; GFX1250:       ; %bb.0:
; GFX1250-NEXT:    s_wait_loadcnt_dscnt 0x0
; GFX1250-NEXT:    s_wait_kmcnt 0x0
; GFX1250-NEXT:    v_fma_mixlo_bf16 v0, v0, v1, v2 op_sel_hi:[1,1,1]
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX1250-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
; GFX1250-NEXT:    v_max_num_f32_e32 v0, 0, v0
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX1250-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
; GFX1250-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX1250-NEXT:    v_min_num_f32_e32 v0, 1.0, v0
; GFX1250-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1)
; GFX1250-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
; GFX1250-NEXT:    s_set_pc_i64 s[30:31]
  %a.f32 = fpext bfloat %src0 to float
  %b.f32 = fpext bfloat %src1 to float
  %c.f32 = fpext bfloat %src2 to float
  %fma.f32 = tail call float @llvm.fmuladd.f32(float %a.f32, float %b.f32, float %c.f32)
  %fma.bf16 = fptrunc float %fma.f32 to bfloat
  %above.zero = call bfloat @llvm.maxnum.bf16(bfloat %fma.bf16, bfloat 0.0)
  %in.unit = call bfloat @llvm.minnum.bf16(bfloat %above.zero, bfloat 1.0)
  %packed = insertelement <2 x bfloat> undef, bfloat %in.unit, i32 1
  ret <2 x bfloat> %packed
}

; Like the post-conversion clamp test, but the unclamped bf16 result has a
; second user: a volatile store. The mixlo result is kept in v1 for the store,
; while the clamped value is computed separately in v0.
; NOTE(review): the store goes through an undef pointer, which only serves to
; keep the value live. A pointer argument or `poison` may be preferred upstream
; but could change the generated checks.
define <2 x bfloat> @v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_undeflo_clamp_postcvt_multi_use(bfloat %src0, bfloat %src1, bfloat %src2) #0 {
; GFX1250-LABEL: v_mad_mixhi_bf16_bf16lo_bf16lo_bf16lo_undeflo_clamp_postcvt_multi_use:
; GFX1250:       ; %bb.0:
; GFX1250-NEXT:    s_wait_loadcnt_dscnt 0x0
; GFX1250-NEXT:    s_wait_kmcnt 0x0
; GFX1250-NEXT:    v_fma_mixlo_bf16 v1, v0, v1, v2 op_sel_hi:[1,1,1]
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX1250-NEXT:    v_lshlrev_b32_e32 v0, 16, v1
; GFX1250-NEXT:    v_max_num_f32_e32 v0, 0, v0
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX1250-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
; GFX1250-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX1250-NEXT:    v_min_num_f32_e32 v0, 1.0, v0
; GFX1250-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
; GFX1250-NEXT:    s_delay_alu instid0(VALU_DEP_1)
; GFX1250-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
; GFX1250-NEXT:    s_wait_storecnt 0x0
; GFX1250-NEXT:    global_store_b16 v[0:1], v1, off scope:SCOPE_SYS
; GFX1250-NEXT:    s_wait_storecnt 0x0
; GFX1250-NEXT:    s_set_pc_i64 s[30:31]
  %a.f32 = fpext bfloat %src0 to float
  %b.f32 = fpext bfloat %src1 to float
  %c.f32 = fpext bfloat %src2 to float
  %fma.f32 = tail call float @llvm.fmuladd.f32(float %a.f32, float %b.f32, float %c.f32)
  %fma.bf16 = fptrunc float %fma.f32 to bfloat
  store volatile bfloat %fma.bf16, ptr addrspace(1) undef
  %above.zero = call bfloat @llvm.maxnum.bf16(bfloat %fma.bf16, bfloat 0.0)
  %in.unit = call bfloat @llvm.minnum.bf16(bfloat %above.zero, bfloat 1.0)
  %packed = insertelement <2 x bfloat> undef, bfloat %in.unit, i32 1
  ret <2 x bfloat> %packed
}

; Intrinsics used by the tests above. The unused @llvm.fmuladd.v2f32
; declaration has been dropped; no function in this file calls it.
declare bfloat @llvm.minnum.bf16(bfloat, bfloat) #1
declare bfloat @llvm.maxnum.bf16(bfloat, bfloat) #1
declare float @llvm.minnum.f32(float, float) #1
declare float @llvm.maxnum.f32(float, float) #1
declare float @llvm.fmuladd.f32(float, float, float) #1

; #0: test functions. f32 denormal mode is "preserve-sign" for both outputs and
; inputs, i.e. f32 denormals are flushed to sign-preserving zero.
attributes #0 = { nounwind "denormal-fp-math-f32"="preserve-sign,preserve-sign" }
; #1: intrinsic declarations. memory(none) is the current spelling of the
; legacy `readnone` attribute and has the same meaning.
attributes #1 = { nounwind speculatable memory(none) }