llvm.org GIT mirror llvm / 9e3c94a
PTX 6.3 extends the `wmma` instruction to support s8/u8/s4/u4/b1 -> s32. All of the new instructions are still handled mostly by tablegen. I've slightly refactored the code to drive intrinsic/instruction generation from a master list of supported variants, so all irregularities have to be implemented in one place only. The test generation script wmma.py has been refactored in a similar way. Differential Revision: https://reviews.llvm.org/D60015 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@359247 91177308-0d34-0410-b5e6-96231b3b80d8 Artem Belevich 1 year, 7 months ago
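The naming rule at the heart of the patch is easiest to see in miniature. Below is a small Python sketch (helper and variable names are illustrative, not part of the patch) of what the new MMA_SIGNATURE class computes: integer and sub-integer variants are identified by the input (A) type, while floating-point variants keep the older D-type/C-type suffix.

def signature(a_type, c_type, d_type):
    # int and sub-int ops are identified by the input type.
    if a_type in ("s8", "u8", "s4", "u4", "b1"):
        return "." + a_type
    # the rest are FP ops, identified by accumulator & result type.
    return ".%s.%s" % (d_type, c_type)

# "llvm.nvvm.wmma.m16n16k16.mma.row.row" + signature("f16", "f32", "f32")
#   -> llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32
# "llvm.nvvm.wmma.m16n16k16.mma.row.row" + signature("s8", "s32", "s32")
#   -> llvm.nvvm.wmma.m16n16k16.mma.row.row.s8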
5 changed files with 811 additions and 169 deletions.
4949 string geom = Geom;
5050 string frag = Frag;
5151 string ptx_elt_type = PtxEltType;
52 string gft = Geom#":"#Frag#":"#ptx_elt_type;
5253 string ft = frag#":"#ptx_elt_type;
5354 list<LLVMType> regs = !cond(
5455 // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
5960 !eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
6061 !eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
6162 !eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret,
62 !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret);
63 !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret,
64
65 // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
66 !eq(gft,"m16n16k16:a:u8") : RepLLVMType<2, llvm_i32_ty>.ret,
67 !eq(gft,"m16n16k16:a:s8") : RepLLVMType<2, llvm_i32_ty>.ret,
68 !eq(gft,"m16n16k16:b:u8") : RepLLVMType<2, llvm_i32_ty>.ret,
69 !eq(gft,"m16n16k16:b:s8") : RepLLVMType<2, llvm_i32_ty>.ret,
70 !eq(gft,"m16n16k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
71 !eq(gft,"m16n16k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
72
73 !eq(gft,"m8n32k16:a:u8") : [llvm_i32_ty],
74 !eq(gft,"m8n32k16:a:s8") : [llvm_i32_ty],
75 !eq(gft,"m8n32k16:b:u8") : RepLLVMType<4, llvm_i32_ty>.ret,
76 !eq(gft,"m8n32k16:b:s8") : RepLLVMType<4, llvm_i32_ty>.ret,
77 !eq(gft,"m8n32k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
78 !eq(gft,"m8n32k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
79
80 !eq(gft,"m32n8k16:a:u8") : RepLLVMType<4, llvm_i32_ty>.ret,
81 !eq(gft,"m32n8k16:a:s8") : RepLLVMType<4, llvm_i32_ty>.ret,
82 !eq(gft,"m32n8k16:b:u8") : [llvm_i32_ty],
83 !eq(gft,"m32n8k16:b:s8") : [llvm_i32_ty],
84 !eq(gft,"m32n8k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
85 !eq(gft,"m32n8k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
86
87 // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
88 !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
89 !eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty],
90 !eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty],
91 !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
92 !eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty],
93 !eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty],
94 !eq(gft,"m8n8k128:c:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
95 !eq(gft,"m8n8k128:d:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
96 !eq(gft,"m8n8k32:c:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
97 !eq(gft,"m8n8k32:d:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
98 );
6399 }
64100
65101 class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
83119 # !if(WithStride, "_stride", "");
84120 }
85121
86 class WMMA_NAME_MMA<string ALayout, string BLayout,
87 WMMA_REGS C, WMMA_REGS D,
88 int Satfinite> {
122 class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
123 list<WMMA_REGS> id_frags = !cond(
124 // int and sub-int ops are identified by input type.
125 !eq(A.ptx_elt_type, "s8") : [A],
126 !eq(A.ptx_elt_type, "u8") : [A],
127 !eq(A.ptx_elt_type, "s4") : [A],
128 !eq(A.ptx_elt_type, "u4") : [A],
129 !eq(A.ptx_elt_type, "b1") : [A],
130 // the rest are FP ops identified by accumulator & result type.
131 1: [D, C]
132 );
133 string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type));
134 }
135
136 class WMMA_NAME_MMA<string ALayout, string BLayout, int Satfinite,
137 WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
138 string signature = MMA_SIGNATURE<A, B, C, D>.ret;
89139 string llvm = "llvm.nvvm.wmma."
90 # C.geom
140 # A.geom
91141 # ".mma"
92142 # "." # ALayout
93143 # "." # BLayout
94 # "." # D.ptx_elt_type // Intrinsic encodes 'd' first.
95 # "." # C.ptx_elt_type
144 # signature
96145 # !if(Satfinite, ".satfinite", "");
97146
98147 string record = !subst(".", "_",
99148 !subst("llvm.", "int_", llvm));
149 }
150
151 // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
152 // Geom: list of supported geometries.
153 // TypeN: PTX type of the corresponding fragment's element.
154 // TypeB and TypeD may be empty; they then default to TypeA and TypeC respectively.
155 class MMA_OPS<list<string> Geom, list<string> TypeA, list<string> TypeB,
156 list<string> TypeC, list<string> TypeD> {
157 list<list<WMMA_REGS>> ret =
158 !foldl([]<list<WMMA_REGS>>, Geom, t1, geom, !listconcat(t1,
159 !foldl([]<list<WMMA_REGS>>, TypeA, t2, type_a, !listconcat(t2,
160 !foldl([]<list<WMMA_REGS>>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3,
161 !foldl([]<list<WMMA_REGS>>, TypeC, t4, type_c, !listconcat(t4,
162 !foldl([]<list<WMMA_REGS>>, !if(!size(TypeD), TypeD, [type_c]), t5, type_d, !listconcat(t5,
163 [[WMMA_REGS<geom, "a", type_a>,
164 WMMA_REGS<geom, "b", type_b>,
165 WMMA_REGS<geom, "c", type_c>,
166 WMMA_REGS<geom, "d", type_d>]]))))))))));
167 // Debugging aid for readable representation of the list above.
168 list<list<string>> ops = !foreach(x, ret, [x[0].gft, x[1].gft, x[2].gft, x[3].gft]);
169 }
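The nested !foldl/!listconcat chains above are TableGen's spelling of a cross product with defaulting. A rough Python equivalent (names are illustrative; the patch's own wmma.py below does the same job with itertools.product):

def mma_ops(geoms, types_a, types_b, types_c, types_d):
    # Empty TypeB/TypeD lists default to the corresponding TypeA/TypeC element.
    return [(g, a, b, c, d)
            for g in geoms
            for a in types_a
            for b in (types_b if types_b else [a])
            for c in types_c
            for d in (types_d if types_d else [c])]

# mma_ops(["m16n16k16"], ["s8", "u8"], [], ["s32"], [])
#   -> [("m16n16k16", "s8", "s8", "s32", "s32"),
#       ("m16n16k16", "u8", "u8", "s32", "s32")]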
170
171 class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
172 list<WMMA_REGS> ret =
173 !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
174 !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
175 !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
176 [WMMA_REGS<geom, frag, type>]))))));
177 // Debugging aid for readable representation of the list above.
178 list<string> ops = !foreach(x, ret, x.gft);
179 }
180
181
182
183 // Creates list of valid combinations of fragments. This is the master list that
184 // drives generation of corresponding intrinsics and instructions.
185 class NVVM_MMA_OPS {
186 list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
187 ["m16n16k16", "m32n8k16", "m8n32k16"],
188 ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
189 list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
190 ["m16n16k16", "m32n8k16", "m8n32k16"],
191 ["s8", "u8"], [], ["s32"], []>.ret;
192 list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
193 ["m8n8k32"],
194 ["s4", "u4"], [], ["s32"], []>.ret;
195 list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
196 ["m8n8k128"],
197 ["b1"], [], ["s32"], []>.ret;
198 list<list<WMMA_REGS>> all_mma_ops = !listconcat(fp_mma_ops, int_mma_ops,
199 subint_mma_ops, bit_mma_ops);
200
201 list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
202 ["m16n16k16", "m32n8k16", "m8n32k16"],
203 ["a", "b"], ["f16", "u8", "s8"]>.ret;
204 list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
205 ["m16n16k16", "m32n8k16", "m8n32k16"],
206 ["c", "d"], ["f16", "f32", "s32"]>.ret;
207 list<WMMA_REGS> ldst_subint_ab_ops = MMA_LDST_OPS<
208 ["m8n8k32"], ["a", "b"], ["s4","u4"]>.ret;
209 list<WMMA_REGS> ldst_bit_ab_ops = MMA_LDST_OPS<
210 ["m8n8k128"], ["a", "b"], ["b1"]>.ret;
211 list<WMMA_REGS> ldst_subint_cd_ops = MMA_LDST_OPS<
212 ["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]>.ret;
213 list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
214 ldst_subint_ab_ops,
215 ldst_bit_ab_ops,
216 ldst_subint_cd_ops);
217 // Separate A/B/C fragments (loads) from D (stores).
218 list<WMMA_REGS> all_ld_ops = !foldl([]<WMMA_REGS>, all_ldst_ops, a, b,
219 !listconcat(a, !if(!eq(b.frag,"d"), [],[b])));
220 list<WMMA_REGS> all_st_ops = !foldl([]<WMMA_REGS>, all_ldst_ops, a, b,
221 !listconcat(a, !if(!eq(b.frag,"d"), [b],[])));
222 }
223
224 def NVVM_MMA_OPS : NVVM_MMA_OPS;
225
226 // Returns [1] if this combination of layout/satf is supported, [] otherwise.
227 // MMA ops must provide all parameters. Loads and stores -- only frags and layout_a.
228 // The class is used to prevent generation of records for the unsupported variants.
229 // E.g.
230 // foreach _ = NVVM_MMA_SUPPORTED<...>.ret in
231 // def : FOO<>; // The record will only be defined for supported ops.
232 //
233 class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b="-", int satf=-1> {
234 // MMA ops check both layouts.
235 string mma = frags[0].ptx_elt_type
236 # ":" # layout_a
237 # ":" # layout_b;
238 // Load ops only need type/fragment/layout.
239 string ld = frags[0].ptx_elt_type
240 # ":" # frags[0].frag
241 # ":" # layout_a
242 ;
243 string ldf = frags[0].ptx_elt_type
244 # ":" # frags[0].frag
245 ;
246 string t = frags[0].ptx_elt_type;
247 list<int> ret = !cond(
248 // Sub-int MMA only supports fixed A/B layout.
249 // b1 does not support .satf.
250 !eq(mma#":"#satf, "b1:row:col:0") : [1],
251 !eq(mma, "s4:row:col") : [1],
252 !eq(mma, "u4:row:col") : [1],
253 !eq(mma, "s4:row:col") : [1],
254 !eq(mma, "u4:row:col") : [1],
255 // Sub-int load/stores have fixed layout for A and B.
256 !and(!eq(layout_b, "-"), // It's a Load or Store op
257 !or(!eq(ld, "b1:a:row"),
258 !eq(ld, "b1:b:col"),
259 !eq(ldf, "b1:c"),
260 !eq(ldf, "b1:d"),
261 !eq(ld, "s4:a:row"),
262 !eq(ld, "s4:b:col"),
263 !eq(ldf, "s4:c"),
264 !eq(ldf, "s4:d"),
265 !eq(ld, "u4:a:row"),
266 !eq(ld, "u4:b:col"),
267 !eq(ldf, "u4:c"),
268 !eq(ldf, "u4:d"))) : [1],
269 // All other sub-int ops are not supported.
270 !eq(t, "b1") : [],
271 !eq(t, "s4") : [],
272 !eq(t, "u4") : [],
273 // All other (non sub-int) are OK.
274 1: [1]
275 );
100276 }
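The same constraints, restated as a runnable Python predicate (a sketch that mirrors is_mma_variant_supported in the updated wmma.py further down):

def is_mma_variant_supported(a_type, layout_a, layout_b, satf):
    if a_type in ("s4", "u4", "b1"):
        if a_type == "b1" and satf:
            return False              # b1 has no .satfinite variant
        return layout_a == "row" and layout_b == "col"
    return True                       # FP and 8-bit int MMA allow any layout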
101277
102278 let TargetPrefix = "nvvm" in {
39694145 WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>;
39704146
39714147 // Create all load/store variants
3972 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
3973 foreach layout = ["row", "col"] in {
3974 foreach stride = [0, 1] in {
3975 foreach frag = [WMMA_REGS<geom, "a", "f16">,
3976 WMMA_REGS<geom, "b", "f16">,
3977 WMMA_REGS<geom, "c", "f16">,
3978 WMMA_REGS<geom, "c", "f32">] in {
3979 def WMMA_NAME_LDST<"load", frag, layout, stride>.record
4148 foreach layout = ["row", "col"] in {
4149 foreach stride = [0, 1] in {
4150 foreach frag = NVVM_MMA_OPS.all_ld_ops in
4151 foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
4152 def WMMA_NAME_LDST<"load", frag, layout, stride>.record
39804153 : NVVM_WMMA_LD<frag, layout, stride>;
3981 }
3982 foreach frag = [WMMA_REGS<geom, "d", "f16">,
3983 WMMA_REGS<geom, "d", "f32">] in {
3984 def WMMA_NAME_LDST<"store", frag, layout, stride>.record
4154 foreach frag = NVVM_MMA_OPS.all_st_ops in
4155 foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
4156 def WMMA_NAME_LDST<"store", frag, layout, stride>.record
39854157 : NVVM_WMMA_ST<frag, layout, stride>;
3986 }
3987 }
39884158 }
39894159 }
39904160
39914161 // WMMA.MMA
3992 class NVVM_WMMA_MMA<string ALayout, string BLayout, string Geom,
3993 WMMA_REGS C, WMMA_REGS D, int Satfinite>
4162 class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite,
4163 WMMA_REGS A, WMMA_REGS B,
4164 WMMA_REGS C, WMMA_REGS D>
39944165 : Intrinsic<D.regs,
3995 !listconcat(
3996 WMMA_REGS<Geom, "a", "f16">.regs,
3997 WMMA_REGS<Geom, "b", "f16">.regs,
3998 C.regs),
4166 !listconcat(A.regs, B.regs, C.regs),
39994167 [IntrNoMem],
4000 WMMA_NAME_MMA<ALayout, BLayout, C, D, Satfinite>.llvm>;
4001
4002 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
4003 foreach layout_a = ["row", "col"] in {
4004 foreach layout_b = ["row", "col"] in {
4005 foreach frag_c = [WMMA_REGS<geom, "c", "f16">,
4006 WMMA_REGS<geom, "c", "f32">] in {
4007 foreach frag_d = [WMMA_REGS<geom, "d", "f16">,
4008 WMMA_REGS<geom, "d", "f32">] in {
4009 foreach satf = [0, 1] in {
4010 def WMMA_NAME_MMA<layout_a, layout_b, frag_c, frag_d, satf>.record
4011 : NVVM_WMMA_MMA<layout_a, layout_b, geom, frag_c, frag_d, satf>;
4012 }
4168 WMMA_NAME_MMA<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
4169
4170 foreach layout_a = ["row", "col"] in {
4171 foreach layout_b = ["row", "col"] in {
4172 foreach satf = [0, 1] in {
4173 foreach op = NVVM_MMA_OPS.all_mma_ops in {
4174 foreach _ = NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret in {
4175 def WMMA_NAME_MMA<layout_a, layout_b, satf,
4176 op[0], op[1], op[2], op[3]>.record
4177 : NVVM_WMMA_MMA<layout_a, layout_b, satf,
4178 op[0], op[1], op[2], op[3]>;
40134179 }
40144180 }
4015 }
4016 }
4017 }
4181 } // satf
4182 } // layout_b
4183 } // layout_a
40184184
40194185 } // let TargetPrefix = "nvvm"
34993499 Info.align = 16;
35003500 return true;
35013501 }
3502 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
3503 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
3504 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
3505 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
3506 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
3507 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
3508 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
3509 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
3510 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
3511 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
3512 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
3513 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
3514 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
3515 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
3516 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
3517 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: {
3518 Info.opc = ISD::INTRINSIC_W_CHAIN;
3519 Info.memVT = MVT::v2i32;
3520 Info.ptrVal = I.getArgOperand(0);
3521 Info.offset = 0;
3522 Info.flags = MachineMemOperand::MOLoad;
3523 Info.align = 8;
3524 return true;
3525 }
3526
3527 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
3528 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
3529 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
3530 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
3531 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
3532 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
3533 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
3534 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
3535
3536 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
3537 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
3538 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
3539 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
3540 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
3541 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
3542 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
3543 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: {
3544 Info.opc = ISD::INTRINSIC_W_CHAIN;
3545 Info.memVT = MVT::v4i32;
3546 Info.ptrVal = I.getArgOperand(0);
3547 Info.offset = 0;
3548 Info.flags = MachineMemOperand::MOLoad;
3549 Info.align = 16;
3550 return true;
3551 }
3552
3553 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
3554 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
3555 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
3556 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
3557 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
3558 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
3559 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
3560 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
3561
3562 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
3563 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
3564 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
3565 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
3566 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
3567 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
3568 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
3569 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
3570 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
3571 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
3572 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
3573 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
3574 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
3575 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
3576 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
3577 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
3578 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
3579 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
3580 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
3581 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col: {
3582 Info.opc = ISD::INTRINSIC_W_CHAIN;
3583 Info.memVT = MVT::i32;
3584 Info.ptrVal = I.getArgOperand(0);
3585 Info.offset = 0;
3586 Info.flags = MachineMemOperand::MOLoad;
3587 Info.align = 4;
3588 return true;
3589 }
35023590
35033591 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
35043592 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
35423630 return true;
35433631 }
35443632
3633 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
3634 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
3635 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
3636 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
3637 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
3638 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
3639 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
3640 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
3641 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
3642 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
3643 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
3644 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
3645 Info.opc = ISD::INTRINSIC_W_CHAIN;
3646 Info.memVT = MVT::v8i32;
3647 Info.ptrVal = I.getArgOperand(0);
3648 Info.offset = 0;
3649 Info.flags = MachineMemOperand::MOLoad;
3650 Info.align = 16;
3651 return true;
3652 }
3653
3654 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
3655 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
3656 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
3657 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
3658 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
3659 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
3660 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
3661 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: {
3662 Info.opc = ISD::INTRINSIC_W_CHAIN;
3663 Info.memVT = MVT::v2i32;
3664 Info.ptrVal = I.getArgOperand(0);
3665 Info.offset = 0;
3666 Info.flags = MachineMemOperand::MOLoad;
3667 Info.align = 8;
3668 return true;
3669 }
3670
35453671 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
35463672 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
35473673 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
35813707 Info.offset = 0;
35823708 Info.flags = MachineMemOperand::MOStore;
35833709 Info.align = 16;
3710 return true;
3711 }
3712
3713 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
3714 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
3715 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
3716 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
3717 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
3718 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
3719 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
3720 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
3721 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
3722 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
3723 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
3724 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
3725 Info.opc = ISD::INTRINSIC_VOID;
3726 Info.memVT = MVT::v8i32;
3727 Info.ptrVal = I.getArgOperand(0);
3728 Info.offset = 0;
3729 Info.flags = MachineMemOperand::MOStore;
3730 Info.align = 16;
3731 return true;
3732 }
3733
3734 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
3735 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
3736 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
3737 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
3738 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
3739 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
3740 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
3741 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
3742 Info.opc = ISD::INTRINSIC_VOID;
3743 Info.memVT = MVT::v2i32;
3744 Info.ptrVal = I.getArgOperand(0);
3745 Info.offset = 0;
3746 Info.flags = MachineMemOperand::MOStore;
3747 Info.align = 8;
35843748 return true;
35853749 }
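The new getTgtMemIntrinsic cases all follow one pattern: an integer fragment occupying N i32 registers is accessed as an N-element i32 vector with natural alignment, capped at 16 bytes. A sketch of that rule (illustrative Python, not code from the patch):

def mem_info(nregs):
    mem_vt = "i32" if nregs == 1 else "v%di32" % nregs
    align = min(16, 4 * nregs)      # bytes
    return (mem_vt, align)

# mem_info(1) -> ("i32", 4)     e.g. m8n8k32/m8n8k128 a/b fragments
# mem_info(2) -> ("v2i32", 8)   e.g. m16n16k16 a/b, m8n8k32 c/d
# mem_info(4) -> ("v4i32", 16)  e.g. m8n32k16 b, m32n8k16 a
# mem_info(8) -> ("v8i32", 16)  e.g. s32 c/d at the 16x16/8x32/32x8 geometries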
35863750
141141 def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">;
142142 def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
143143 def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
144 def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">;
144145
145146 def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
146147 def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;
148 def hasSM72 : Predicate<"Subtarget->getSmVersion() >= 72">;
149 def hasSM75 : Predicate<"Subtarget->getSmVersion() >= 75">;
147150
148151 def useShortPtr : Predicate<"useShortPointers()">;
149152 def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
74057405 // In addition to target-independent fields provided by WMMA_REGS, it adds
74067406 // the fields commonly used to implement specific PTX instruction -- register
74077407 // types and names, constraints, parts of assembly, etc.
7408 class WMMA_REGINFO<string Geom, string Frag, string PtxEltType>
7409 : WMMA_REGS<Geom, Frag, PtxEltType> {
7408 class WMMA_REGINFO<WMMA_REGS r>
7409 : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
74107410 // NVPTX register types used to carry fragment data.
74117411 NVPTXRegClass regclass = !cond(
7412 !eq(PtxEltType, "f16") : Float16x2Regs,
7413 !eq(PtxEltType, "f32") : Float32Regs);
7412 !eq(ptx_elt_type, "f16") : Float16x2Regs,
7413 !eq(ptx_elt_type, "f32") : Float32Regs,
7414 !eq(ptx_elt_type, "s32") : Int32Regs,
7415 !eq(ptx_elt_type, "s8") : Int32Regs,
7416 !eq(ptx_elt_type, "u8") : Int32Regs,
7417 !eq(ptx_elt_type, "s4") : Int32Regs,
7418 !eq(ptx_elt_type, "u4") : Int32Regs,
7419 !eq(ptx_elt_type, "b1") : Int32Regs);
74147420
74157421 // Instruction input/output arguments for the fragment.
74167422 list<NVPTXRegClass> ptx_regs = !foreach(tmp, regs, regclass);
74327438 // all fragments of the instruction are viable.
74337439 list<Predicate> Predicates = !cond(
74347440 // fp16 -> fp16/fp32 @ m16n16k16
7435 !and(!eq(Geom, "m16n16k16"),
7436 !or(!eq(PtxEltType, "f16"),
7437 !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60],
7441 !and(!eq(geom, "m16n16k16"),
7442 !or(!eq(ptx_elt_type, "f16"),
7443 !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60],
74387444
74397445 // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
7440 !and(!or(!eq(Geom, "m8n32k16"),
7441 !eq(Geom, "m32n8k16")),
7442 !or(!eq(PtxEltType, "f16"),
7443 !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]);
7446 !and(!or(!eq(geom, "m8n32k16"),
7447 !eq(geom, "m32n8k16")),
7448 !or(!eq(ptx_elt_type, "f16"),
7449 !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX61],
7450
7451 // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
7452 !and(!or(!eq(geom,"m16n16k16"),
7453 !eq(geom,"m8n32k16"),
7454 !eq(geom,"m32n8k16")),
7455 !or(!eq(ptx_elt_type, "u8"),
7456 !eq(ptx_elt_type, "s8"),
7457 !eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63],
7458
7459 // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
7460 !or(!eq(geom,"m8n8k128"),
7461 !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63]);
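Summarizing the predicate table above as data (an illustrative Python mapping, not part of the patch): each variant class carries a minimum (sm, ptx) pair.

WMMA_REQUIREMENTS = {
    "f16/f32 @ m16n16k16":                    (70, 60),  # hasSM70, hasPTX60
    "f16/f32 @ m8n32k16, m32n8k16":           (70, 61),  # hasSM70, hasPTX61
    "u8/s8/s32 @ m16n16k16/m8n32k16/m32n8k16": (72, 63),  # hasSM72, hasPTX63
    "u4/s4 @ m8n8k32; b1 @ m8n8k128":         (75, 63),  # hasSM75, hasPTX63
}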
74447462
74457463 // template DAGs for instruction inputs/output.
74467464 dag Outs = !dag(outs, ptx_regs, reg_names);
75587576
75597577 // Create all load/store variants
75597577 defset list<WMMA_INSTR> MMA_LDSTs = {
7561 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
7562 foreach layout = ["row", "col"] in {
7563 foreach stride = [0, 1] in {
7564 foreach space = [".global", ".shared", ""] in {
7565 foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
7566 foreach frag = [WMMA_REGINFO<geom, "a", "f16">,
7567 WMMA_REGINFO<geom, "b", "f16">,
7568 WMMA_REGINFO<geom, "c", "f16">,
7569 WMMA_REGINFO<geom, "c", "f32">] in {
7570 def : WMMA_LOAD<frag, layout, space, stride, addr>;
7571 }
7572 foreach frag = [WMMA_REGINFO<geom, "d", "f16">,
7573 WMMA_REGINFO<geom, "d", "f32">] in {
7574 def : WMMA_STORE_D<frag, layout, space, stride, addr>;
7575 }
7576 } // addr
7577 } // space
7578 } // stride
7579 } // layout
7580 } // geom
7579 foreach layout = ["row", "col"] in {
7580 foreach stride = [0, 1] in {
7581 foreach space = [".global", ".shared", ""] in {
7582 foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
7583 foreach frag = NVVM_MMA_OPS.all_ld_ops in
7584 foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
7585 def : WMMA_LOAD<WMMA_REGINFO<frag>, layout, space, stride, addr>;
7586 foreach frag = NVVM_MMA_OPS.all_st_ops in
7587 foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
7588 def : WMMA_STORE_D<WMMA_REGINFO<frag>, layout, space, stride, addr>;
7589 } // addr
7590 } // space
7591 } // stride
7592 } // layout
75817593 } // defset
75827594
75837595 // WMMA.MMA
75847596 class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
75857597 WMMA_REGINFO FragC, WMMA_REGINFO FragD,
75867598 string ALayout, string BLayout, int Satfinite>
7587 : WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, FragC, FragD, Satfinite>.record,
7599 : WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
75887600 [FragA.Ins, FragB.Ins, FragC.Ins]>,
7589 Requires<FragA.Predicates> {
7601 // Requires does not seem to have effect on Instruction w/o Patterns.
7602 // We set it here anyways and propagate to the Pat<> we construct below.
7603 Requires<FragA.Predicates> {
75907604 let OutOperandList = FragD.Outs;
75917605 let InOperandList = !con(Args, (ins MmaCode:$ptx));
7592 let AsmString = "wmma.mma.sync"
7606 string TypeList = !cond(
7607 !eq(FragD.ptx_elt_type, "s32") : ".s32"
7608 # "." # FragA.ptx_elt_type
7609 # "." # FragB.ptx_elt_type
7610 # ".s32",
7611 1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type,
7612 );
7613 let AsmString = "wmma.mma"
7614 # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
7615 # ".sync"
75937616 # "${ptx:aligned}"
75947617 # "." # ALayout
75957618 # "." # BLayout
75967619 # "." # FragA.geom
7597 # "." # FragD.ptx_elt_type
7598 # "." # FragC.ptx_elt_type
7620 # TypeList
75997621 # !if(Satfinite, ".satfinite", "") # "\n\t\t"
76007622 # FragD.regstring # ",\n\t\t"
76017623 # FragA.regstring # ",\n\t\t"
76047626 }
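Two concrete strings the new AsmString logic produces (worked examples, assuming the .aligned modifier is in effect):

# b1 op: TypeList spells out all four types and the variant adds .xor.popc:
#   wmma.mma.xor.popc.sync.aligned.row.col.m8n8k128.s32.b1.b1.s32
# FP op: TypeList keeps the old two-type D.C form:
#   wmma.mma.sync.aligned.row.col.m16n16k16.f32.f16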
76057627
76067628 defset list<WMMA_INSTR> MMAs = {
7607 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
7608 foreach layout_a = ["row", "col"] in {
7609 foreach layout_b = ["row", "col"] in {
7610 foreach frag_c = [WMMA_REGINFO<geom, "c", "f16">,
7611 WMMA_REGINFO<geom, "c", "f32">] in {
7612 foreach frag_d = [WMMA_REGINFO<geom, "d", "f16">,
7613 WMMA_REGINFO<geom, "d", "f32">] in {
7614 foreach satf = [0, 1] in {
7615 def : WMMA_MMA<WMMA_REGINFO<geom, "a", "f16">,
7616 WMMA_REGINFO<geom, "b", "f16">,
7617 frag_c, frag_d, layout_a, layout_b, satf>;
7618 } // satf
7619 } // frag_d
7620 } // frag_c
7621 } // layout_b
7622 } // layout_a
7623 } // geom
7629 foreach layout_a = ["row", "col"] in {
7630 foreach layout_b = ["row", "col"] in {
7631 foreach satf = [0, 1] in {
7632 foreach op = NVVM_MMA_OPS.all_mma_ops in {
7633 foreach _ = NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret in {
7634 def : WMMA_MMA<WMMA_REGINFO<op[0]>,
7635 WMMA_REGINFO<op[1]>,
7636 WMMA_REGINFO<op[2]>,
7637 WMMA_REGINFO<op[3]>,
7638 layout_a, layout_b, satf>;
7639 }
7640 } // op
7641 } // satf
7642 } // layout_b
7643 } // layout_a
76247644 } // defset
7645
76257646
76267647 // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
76277648 // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
76297650 class WMMA_PAT<WMMA_INSTR wi>
76307651 : Pat<wi.IntrinsicPattern,
76317652 !con(!foreach(tmp, wi.Args, !subst(ins, wi, tmp)),
7632 (wi ptx.version))>;
7653 (wi ptx.version))>,
7654 Requires<wi.Predicates>;
76337655
76347656 // Build intrinsic->instruction patterns for all MMA instructions.
76357657 foreach mma = !listconcat(MMAs, MMA_LDSTs) in
0 # This test generates all variants of wmma intrinsics and verifies that LLVM
11 # generates correct instructions for them.
22
3 # RUN: python %s > %t.ll
4 # RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll
5 # RUN: python %s --ptx=63 > %t-ptx63.ll
6 # RUN: llc < %t-ptx63.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx63 | FileCheck %t-ptx63.ll
3 # Check all variants of instructions supported by PTX60 on SM70
4 # RUN: python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
5 # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
6 # RUN: --check-prefixes=INTRINSICS,PTX60,SM70
7 # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
8 # RUN: --check-prefixes=INTRINSICS,PTX60U,SM70U
9 # RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
10 # RUN: | FileCheck %t-ptx60-sm_70.ll
11
12 # Check all variants of instructions supported by PTX61 on SM70
13 # RUN: python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
14 # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
15 # RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,SM70
16 # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
17 # RUN: --check-prefixes=INTRINSICS,PTX61U,SM70U
18 # RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
19 # RUN: | FileCheck %t-ptx61-sm_70.ll
20
21 # Check all variants of instructions supported by PTX63 on SM72
22 # RUN: python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
23 # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
24 # RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72
25 # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
26 # RUN: --check-prefixes=INTRINSICS,PTX63U,SM72U
27 # RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
28 # RUN: | FileCheck %t-ptx63-sm_72.ll
29
30 # Check all variants of instructions supported by PTX63 on SM75
31 # RUN: python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
32 # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
33 # RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72,SM75
34 # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
35 # RUN: --check-prefixes=INTRINSICS,PTX63U,SM75U
36 # RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
37 # RUN: | FileCheck %t-ptx63-sm_75.ll
38
739
840 from __future__ import print_function
941
1143 from itertools import product
1244 from string import Template
1345
14 def make_wmma_slice_ty(abcd, itype):
15 elt_ty = "<2 x half>" if itype == "f16" else "float"
16 num_elts = 4 if abcd in "cd" and itype == "f16" else 8;
17 return [elt_ty] * num_elts
18
19 def make_wmma_ld_ret_ty(abc, itype):
20 return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
46 class MMAType:
47 def __init__(self, ptx_type):
48 self.ptx_type = ptx_type
49 self.llvm_type = {
50 "f16" : "<2 x half>",
51 "f32" : "float",
52 "s32" : "i32",
53 "s8" : "i32",
54 "u8" : "i32",
55 "s4" : "i32",
56 "u4" : "i32",
57 "b1" : "i32",
58 }[ptx_type];
59
60 self.ptx_reg_pattern = {
61 "f16" : "%hh[0-9]+",
62 "f32" : "%f[0-9]+",
63 }.get(ptx_type, "%r[0-9]+")
64
65 def __repr__(self):
66 return "%s/%s" % (self.ptx_type, self.llvm_type)
67
68 class MMAFrag:
69 def __init__(self, geom, frag, ptx_elt_type):
70 self.geom = geom
71 self.frag = frag
72 self.mma_type = MMAType(ptx_elt_type);
73 self.nregs = {
74 "a:f16" : 8,
75 "b:f16" : 8,
76 "c:f16" : 4,
77 "d:f16" : 4,
78 "c:f32" : 8,
79 "d:f32" : 8,
80 }.get("%s:%s" % (frag, ptx_elt_type), {
81 # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
82 "m16n16k16:a:u8" : 2,
83 "m16n16k16:a:s8" : 2,
84 "m16n16k16:b:u8" : 2,
85 "m16n16k16:b:s8" : 2,
86 "m16n16k16:c:s32" : 8,
87 "m16n16k16:d:s32" : 8,
88
89 "m8n32k16:a:u8" : 1,
90 "m8n32k16:a:s8" : 1,
91 "m8n32k16:b:u8" : 4,
92 "m8n32k16:b:s8" : 4,
93 "m8n32k16:c:s32" : 8,
94 "m8n32k16:d:s32" : 8,
95
96 "m32n8k16:a:u8" : 4,
97 "m32n8k16:a:s8" : 4,
98 "m32n8k16:b:u8" : 1,
99 "m32n8k16:b:s8" : 1,
100 "m32n8k16:c:s32" : 8,
101 "m32n8k16:d:s32" : 8,
102
103 # u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
104 "m8n8k128:a:b1" : 1,
105 "m8n8k32:a:u4" : 1,
106 "m8n8k32:a:s4" : 1,
107 "m8n8k128:b:b1" : 1,
108 "m8n8k32:b:u4" : 1,
109 "m8n8k32:b:s4" : 1,
110 "m8n8k128:c:s32" : 2,
111 "m8n8k128:d:s32" : 2,
112 "m8n8k32:c:s32" : 2,
113 "m8n8k32:d:s32" : 2,
114 }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), None));
115 assert(self.nregs);
116
117 def __repr__(self):
118 return "%s:%s:%s%s" % (self.geom, self.frag, self.mma_type,
119 "" if self.nregs == 1 else ("*%d" % self.nregs))
120
121 class MMAOp:
122 def __init__(self, a, b, c, d):
123 self.a = a
124 self.b = b
125 self.c = c
126 self.d = d
127
128 def __repr__(self):
129 return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d ))
130
131 def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
132 ops = []
133 for geom, type_a, type_c in product( geoms, types_a, types_c):
134 for type_b, type_d in product(types_b if types_b else [type_a],
135 types_d if types_d else [type_c]):
136 ops.append(MMAOp(MMAFrag(geom, "a", type_a),
137 MMAFrag(geom, "b", type_b),
138 MMAFrag(geom, "c", type_c),
139 MMAFrag(geom, "d", type_d)))
140 return ops
141
142 def make_ldst_ops(geoms, frags, types):
143 return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
144 in product(geoms, frags, types)]
145
146 def get_mma_ops():
147 return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
148 ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
149 make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
150 ["s8", "u8"], [], ["s32"], []) +
151 make_mma_ops(["m8n8k32"],
152 ["s4", "u4"], [], ["s32"], []) +
153 make_mma_ops(["m8n8k128"],
154 ["b1"], [], ["s32"], []))
155 def get_ldst_ops(kind):
156 ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
157 ["a", "b"], ["f16", "u8", "s8"]) +
158 make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
159 ["c", "d"], ["f16", "f32", "s32"]) +
160 make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) +
161 make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
162 make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]))
163 return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
164
165 def is_geom_supported(geom):
166 # geometries for FP and ints.
167 if geom in ["m8n32k16", "m32n8k16"]:
168 return ptx_version >= 61
169 # geometries for sub-ints.
170 if geom in ["m8n8k32", "m8n8k128"]:
171 return ptx_version >= 63 and gpu_arch >= 75
172 if geom == "m16n16k16":
173 return ptx_version >= 60
174 assert(False) # Unexpected geometry.
175
176 def is_type_supported(ptx_type):
177 if ptx_type in ["s8", "u8", "s32"]:
178 return ptx_version >= 63 and gpu_arch >= 72
179 if ptx_type in ["s4", "u4", "b1"]:
180 return ptx_version >= 63 and gpu_arch >= 75
181 return ptx_version >= 60 and gpu_arch >= 70
182
183
184 def is_mma_variant_supported(op, layout_a, layout_b, satf):
185 if not (is_type_supported(op.a.mma_type.ptx_type)
186 and is_geom_supported(op.a.geom)):
187 return False
189 # sub-integer ops require row/col layout, and no satf.
189 if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
190 if op.a.mma_type.ptx_type == "b1" and satf:
191 return False
192 return layout_a == "row" and layout_b == "col"
193 return True
194
195 def is_ldst_variant_supported(frag, layout):
196 if not (is_type_supported(frag.mma_type.ptx_type)
197 and is_geom_supported(frag.geom)):
198 return False
199 if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
200 # sub-integer ops require sm_75 and ptx63, row/col layout for a/b.
201 return ((frag.frag == "a" and layout == "row")
202 or (frag.frag == "b" and layout == "col")
203 or frag.frag in ["c", "d"])
204 return True
205
206 def make_wmma_slice_ty(frag):
207 return [frag.mma_type.llvm_type] * frag.nregs
208
209 def make_wmma_ld_ret_ty(frag):
210 results = make_wmma_slice_ty(frag)
211 if len(results) == 1:
212 return "%s" % results[0]
213 return "{%s}" % ", ".join(results)
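# Illustration (not part of the original script): what the helpers above
# produce for two representative fragments.
#   make_wmma_ld_ret_ty(MMAFrag("m16n16k16", "c", "f16"))
#     -> "{<2 x half>, <2 x half>, <2 x half>, <2 x half>}"
#   make_wmma_ld_ret_ty(MMAFrag("m8n8k32", "a", "u4"))
#     -> "i32"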
21214
22215 # returns address space
23216 def get_aspace(space):
35228 def get_pspace(space):
36229 return "p%di8" % get_aspace(space);
37230
38 # Convenient test patterns.
39 check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
40 check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
41 check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
231 def check_pattern(frag):
232 return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)
42233
43234 known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
44235
68259 intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
69260 instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
70261
71 for geom, abc, layout, space, stride, itype in product(
72 known_geoms,
73 "abc",
262 generated_items = []
263
264 for frag, layout, space, stride in product(
265 get_ldst_ops("load"),
74266 ["row","col"],
75267 ["",".shared",".global"],
76268 ["", ".stride"],
77 ["f16", "f32"]):
269 ):
270 if not is_ldst_variant_supported(frag, layout):
271 continue
78272
79273 params = {
80 "abc" : abc,
274 "abc" : frag.frag,
81275 "aligned" : ".aligned" if ptx_version >= 63 else "",
82276 "layout" : layout,
83277 "space" : space,
84278 "stride" : stride,
85 "itype" : itype,
279 "itype" : frag.mma_type.ptx_type,
86280 "pspace" : get_pspace(space),
87281 "as" : "addrspace(%d)" % get_aspace(space),
88 "geom" : geom,
282 "geom" : frag.geom,
89283 }
90
91 if itype == "f32" and abc != "c":
92 continue
93284
94285 test_params = params
95286 test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
96287 test_params["function"] = test_params["intrinsic"].replace(".","_")
97288 test_params["instruction"] = Template(instruction_template).substitute(params)
98 test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
99 if abc == "c" :
100 test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
101 else:
102 test_params["check_result"] = check_f16_8
289 test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
290 test_params["check_result"] = check_pattern(frag)
103291
104292 if stride:
105293 test_params["extra_args"] = ", i32 %stride";
110298
111299 print(Template(load_template).substitute(test_params))
112300
113 def make_wmma_slice_args(itype, abcd, prefix="v"):
114 return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t
115 in enumerate(make_wmma_slice_ty(abcd, itype))])
301 generated_items.append((test_params["intrinsic"],
302 test_params["instruction"]))
303
304 return generated_items
305
306 def make_wmma_slice_args(frag):
307 return ", ".join(["%s %%%s%d" % (t, frag.frag, i) for i,t
308 in enumerate(make_wmma_slice_ty(frag))])
116309
117310 def gen_wmma_store_tests():
118311 store_template = """
140333 intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
141334 instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
142335
143 for geom, abc, layout, space, stride, itype in product(
144 known_geoms,
145 "d",
336 generated_items = []
337
338 for frag, layout, space, stride in product(
339 get_ldst_ops("store"),
146340 ["row","col"],
147341 ["",".shared",".global"],
148 ["", ".stride"],
149 ["f16", "f32"]):
342 ["", ".stride"]):
343
344 if not is_ldst_variant_supported(frag, layout):
345 continue
150346
151347 params = {
152 "abc" : abc,
348 "abc" : frag.frag,
153349 "aligned" : ".aligned" if ptx_version >= 63 else "",
154350 "layout" : layout,
155351 "space" : space,
156352 "stride" : stride,
157 "itype" : itype,
353 "itype" : frag.mma_type.ptx_type,
158354 "pspace" : get_pspace(space),
159355 "as" : "addrspace(%d)" % get_aspace(space),
160 "geom" : geom,
356 "geom" : frag.geom,
161357 }
162358
163359 test_params = params
164360 test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
165361 test_params["function"] = test_params["intrinsic"].replace(".","_")
166362 test_params["instruction"] = Template(instruction_template).substitute(params)
167 test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
168 test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
363 test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
364 test_params["check_args"] = check_pattern(frag)
169365 if stride:
170366 test_params["extra_args"] = ", i32 %stride";
171367 test_params["stride_pattern"] = ", %r{{[0-9]+}};"
172368 else:
173369 test_params["extra_args"] = ""
174370 test_params["stride_pattern"] = ";"
175 test_params["args"] = make_wmma_slice_args(itype, "d");
371 test_params["args"] = make_wmma_slice_args(frag);
176372
177373 print(Template(store_template).substitute(test_params))
374 generated_items.append((test_params["intrinsic"],
375 test_params["instruction"]))
376
377 return generated_items
378
379 def mma_signature(op):
380 if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
381 # int and sub-int ops are identified by input type.
382 return op.a.mma_type.ptx_type
383 else:
384 # the rest are FP ops identified by accumulator & result type.
385 return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
386
387 def mma_ptx_signature(op):
388 if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
389 # int and sub-int instructions encode all four types as D.A.B.C
390 return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
391 else:
392 # the rest are FP instructions, which use D.C
393 return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
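# Illustration (not part of the original script): for an s8 op,
#   mma_signature(op)     -> "s8"             (intrinsic name suffix)
#   mma_ptx_signature(op) -> "s32.s8.s8.s32"  (PTX spells out all four types as D.A.B.C)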
178394
179395 def gen_wmma_mma_tests():
180396 mma_template = """
186402 ${args}) {
187403 ; CHECK: ${instruction}
188404 ; CHECK-NEXT: ${check_d}
189 ; CHECK-NEXT: ${check_ab}
190 ; CHECK-NEXT: ${check_ab}
405 ; CHECK-NEXT: ${check_a}
406 ; CHECK-NEXT: ${check_b}
191407 ; CHECK-NEXT: ${check_c}
192408 %r = call ${ret_ty} @${intrinsic}(
193409 ${args});
194410 ret ${ret_ty} %r;
195411 }
196412 """
197 intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
198 instruction_template = "wmma.mma.sync${aligned}.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
199
200 for geom, alayout, blayout, ctype, dtype, satf in product(
201 known_geoms,
413 intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
414 instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"
415
416 generated_items=[]
417
418 for op, alayout, blayout, satf in product(
419 get_mma_ops(),
202420 ["row","col"],
203421 ["row","col"],
204 ["f16", "f32"],
205 ["f16", "f32"],
206422 [".satfinite", ""]):
423
424 if not is_mma_variant_supported(op, alayout, blayout, satf):
425 continue
207426
208427 params = {
209428 "aligned" : ".aligned" if ptx_version >= 63 else "",
210429 "alayout" : alayout,
211430 "blayout" : blayout,
212 "ctype" : ctype,
213 "dtype" : dtype,
431 "intrinsic_signature" : mma_signature(op),
432 "ptx_signature" : mma_ptx_signature(op),
214433 "satf" : satf,
215 "geom" : geom,
434 "geom" : op.a.geom,
435 "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
216436 }
217437
218438 test_params = params
219439 test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
220440 test_params["function"] = test_params["intrinsic"].replace(".", "_")
221441 test_params["instruction"] = Template(instruction_template).substitute(params)
222 test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
223 test_params["check_ab"] = check_f16_8
224 test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
225 test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8
226 args = ",\n ".join(make_wmma_slice_args(t, abcd, prefix=abcd)
227 for abcd, t in (("a", "f16"),
228 ("b", "f16"),
229 ("c", ctype)))
442 test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
443 test_params["check_a"] = check_pattern(op.a)
444 test_params["check_b"] = check_pattern(op.b)
445 test_params["check_c"] = check_pattern(op.c)
446 test_params["check_d"] = check_pattern(op.d)
447 args = ",\n ".join(make_wmma_slice_args(frag)
448 for frag in (op.a, op.b, op.c))
230449 test_params["args"] = args
231450 print(Template(mma_template).substitute(test_params))
232
233 def main():
234 gen_wmma_load_tests()
235 gen_wmma_store_tests()
236 gen_wmma_mma_tests()
451 generated_items.append((test_params["intrinsic"],
452 test_params["instruction"]))
453
454 return generated_items
455
456 # Append complete list of intrinsics and instructions we've generated tests for.
457 # Generate a set of checks to verify that we did generate a sensible set of
458 # tests for the given combination of PTX and SM variants.
459 #
460 # PTX: verifies that we did generate tests for correct classes of intrinsics.
461 # PTXU: verifies that we did not generate intrinsics unsupported by
462 # the PTX version.
463 # SM: verifies that we did generate correct classes of instructions for the SM.
464 # SMU: verifies that we did not generate instructions unsupported by the SM
465 #
466 # Note that SM/PTX constraints overlap, but DAG checks do not allow overlapping
467 # matches. We implicitly rely on generating multiple variants of most of the
468 # instructions and usually have enough input data to find more than one match of
469 # the same kind, if necessary. When it's not possible (e.g. there's only one
470 # m8n8k128.mma.row.col.b1), we may need to match PTX instruction instead.
471 def gen_check_unsupported_ops(items):
472 print("; Complete list of intrinsics supported by PTX%d on sm_%d"
473 % (ptx_version, gpu_arch))
474 print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
475 print("""
476 ; PTX60-DAG: m16n16k16.load.{{[ab].*}}.f16.p
477 ; PTX60-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
478 ; PTX60U-NOT: m32n8k16
479 ; PTX60U-NOT: m8n32k16
480 ; PTX60U-NOT: .{{s32|s[48]|u[48]|b1}}
481
482 ; All features of PTX60, plus m32n8k16/m8n32k16 geometries.
483 ; PTX61-DAG: m32n8k16.load.{{[ab].*}}.f16.p
484 ; PTX61-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
485 ; PTX61-DAG: m8n32k16.load.{{[ab].*}}.f16.p
486 ; PTX61-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
487 ; PTX61U-NOT: .{{s32|s[48]|u[48]|b1}}
488
489 ; SM70U-NOT: .{{s32|s[48]|u[48]|b1}}
490
491 ; PTX63 supports all features of PTX60+PTX61, plus support for integers.
492 ; Alas, we can't just use PTX checks for that, as the available instructions
493 ; depend on the SM: integers need sm72+ and sub-integer ops need sm75, so we
494 ; transition to SM checks.
495 ; SM72-DAG: m16n16k16.load.{{[ab].*}}.s8.p
496 ; SM72-DAG: m8n32k16.load.{{[ab].*}}.s8.p
497 ; SM72-DAG: m32n8k16.load.{{[ab].*}}.s8.p
498 ; SM72-DAG: m16n16k16.load.{{[ab].*}}.u8.p
499 ; SM72-DAG: m8n32k16.load.{{[ab].*}}.u8.p
500 ; SM72-DAG: m32n8k16.load.{{[ab].*}}.u8.p
501 ; SM72-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
502 ; SM72U-NOT: .{{s4|u4|b1}}
503
504 ; SM75-DAG: m8n8k128.load.{{[ab].*}}.b1.p
505 ; SM75-DAG: m8n8k32.load.{{[ab].*}}.s4.p
506 ; SM75-DAG: m8n8k32.load.{{[ab].*}}.u4.p
507 ; SM75-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
508 ; SM75-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
509 """)
510
511 print("; INTRINSICS_LIST_BEGIN")
512 for intrinsic, instruction in sorted(items):
513 print("; ", intrinsic, " -> ", instruction,"")
514 print("; INTRINSICS_LIST_END")
515 print("; INTRINSICS: ; INTRINSICS_LIST_END")
516
517 def gen_tests():
518 items = gen_wmma_load_tests()
519 items += gen_wmma_store_tests()
520 items += gen_wmma_mma_tests()
521 gen_check_unsupported_ops(items)
237522
238523 parser = argparse.ArgumentParser()
239 parser.add_argument('--ptx', type=int, default=60)
524 parser.add_argument("--ptx", type=int, default=60)
525 parser.add_argument("--gpu-arch", type=int, default=70)
240526 args = parser.parse_args()
241527 ptx_version = args.ptx
242
243 main()
528 gpu_arch = args.gpu_arch
529
530 gen_tests()