llvm.org GIT mirror llvm / 5446c67
[NVPTX] generate correct MMA instruction mnemonics with PTX63+. PTX 6.3 requires using ".aligned" in the MMA instruction names. In order to generate the correct name, we now pass the current PTX version to each instruction as an extra constant operand, and InstPrinter adjusts its output accordingly. Differential Revision: https://reviews.llvm.org/D59393 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@359246 91177308-0d34-0410-b5e6-96231b3b80d8 Artem Belevich 1 year, 7 months ago
5 changed file(s) with 185 addition(s) and 133 deletion(s). Raw diff Collapse all Expand all
268268 llvm_unreachable("Empty Modifier");
269269 }
270270
// Prints the MmaCode operand, whose immediate encodes the PTX version in use.
// With no modifier (or "version"): prints the raw PTX version number.
// With the "aligned" modifier: emits ".aligned" when the PTX version is >= 6.3,
// because PTX 6.3 requires '.aligned' in MMA instruction names; for older
// PTX versions nothing is printed.
271 void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O,
272 const char *Modifier) {
273 const MCOperand &MO = MI->getOperand(OpNum);
274 int Imm = (int)MO.getImm();
275 if (Modifier == nullptr || strcmp(Modifier, "version") == 0) {
276 O << Imm; // Just print out PTX version
277 } else if (strcmp(Modifier, "aligned") == 0) {
278 // PTX63 requires '.aligned' in the name of the instruction.
279 if (Imm >= 63)
280 O << ".aligned";
281 } else
282 llvm_unreachable("Unknown Modifier");
283 }
284
271285 void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
272286 raw_ostream &O, const char *Modifier) {
273287 printOperand(MI, OpNum, O);
3939 const char *Modifier = nullptr);
4040 void printLdStCode(const MCInst *MI, int OpNum,
4141 raw_ostream &O, const char *Modifier = nullptr);
42 void printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O,
43 const char *Modifier = nullptr);
4244 void printMemOperand(const MCInst *MI, int OpNum,
4345 raw_ostream &O, const char *Modifier = nullptr);
4446 void printProtoIdent(const MCInst *MI, int OpNum,
15451545
15461546 def LdStCode : Operand {
15471547 let PrintMethod = "printLdStCode";
1548 }
1549
1550 def MmaCode : Operand {
1551 let PrintMethod = "printMmaCode";
15481552 }
15491553
15501554 def SDTWrapper : SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>, SDTCisPtrTy<0>]>;
3535 code global = [{
3636 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
3737 }];
38 }
39
40 // A node that will be replaced with the current PTX version.
41 class PTX {
42 SDNodeXForm PTXVerXform = SDNodeXForm
43 return getI32Imm(Subtarget->getPTXVersion(), SDLoc(N));
44 }]>;
45 // (i32 0) will be XForm'ed to the currently used PTX version.
46 dag version = (PTXVerXform (i32 0));
47 }
48 def ptx : PTX;
49
50 // Generates list of n sequential register names.
51 // E.g. RegNames<3,"r">.ret -> ["r0", "r1", "r2" ]
52 class RegSeq {
53 list ret = !if(n, !listconcat(RegSeq.ret,
54 [prefix # !add(n, -1)]),
55 []);
3856 }
3957
4058 //-----------------------------------
73837401 NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;",
73847402 [(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>;
73857403
7386 class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
7387 // Generates list of n sequential register names.
7388 class RegSeq {
7389 list ret = !if(n, !listconcat(RegSeq.ret,
7390 [prefix # !add(n, -1)]),
7391 []);
7392 }
7393
73947404 // Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
73957405 // In addition to target-independent fields provided by WMMA_REGS, it adds
73967406 // the fields commonly used to implement specific PTX instruction -- register
74077417
74087418 // List of register names for the fragment -- ["ra0", "ra1",...]
74097419 list reg_names = RegSeq.ret;
7420
74107421 // Generates "{{$r0, $r1,.... $rN-1}}" for use in asm string construction.
74117422 string regstring = "{{$" # !head(reg_names)
74127423 # !foldl("", !tail(reg_names), a, b,
74367447 dag Ins = !dag(ins, ptx_regs, reg_names);
74377448 }
74387449
7439 class BuildPattern {
7450 // Convert dag of arguments into a dag to match given intrinsic.
7451 class BuildPatternI {
74407452 // Build a dag pattern that matches the intrinsic call.
7441 // We want a dag that looks like this:
7442 // (set , (intrinsic )) where input and
7443 // output arguments are named patterns that would match corresponding
7444 // input/output arguments of the instruction.
7445 //
7446 // First we construct (set ) from instruction's outs dag by
7447 // replacing dag operator 'outs' with 'set'.
7448 dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp));
7449 // Similarly, construct (intrinsic ) sub-dag from
7450 // instruction's input arguments, only now we also need to replace operands
7451 // with patterns that would match them and the operator 'ins' with the
7452 // intrinsic.
7453 dag PatArgs = !foreach(tmp, Ins,
7454 !subst(imem, ADDRvar,
7455 !subst(MEMri64, ADDRri64,
7456 !subst(MEMri, ADDRri,
7457 !subst(ins, IntrMatcher, tmp)))));
7458 // Finally, concatenate both parts together. !con() requires both dags to have
7459 // the same operator, so we wrap PatArgs in a (set ...) dag.
7460 dag ret = !con(PatOuts, (set PatArgs));
7453 dag ret = !foreach(tmp, Ins,
7454 !subst(imem, ADDRvar,
7455 !subst(MEMri64, ADDRri64,
7456 !subst(MEMri, ADDRri,
7457 !subst(ins, Intr, tmp)))));
7458 }
7459
7460 // Same as above, but uses PatFrag instead of an Intrinsic.
7461 class BuildPatternPF {
7462 // Build a dag pattern that matches the intrinsic call.
7463 dag ret = !foreach(tmp, Ins,
7464 !subst(imem, ADDRvar,
7465 !subst(MEMri64, ADDRri64,
7466 !subst(MEMri, ADDRri,
7467 !subst(ins, Intr, tmp)))));
7468 }
7469
7470 // Common WMMA-related fields used for building patterns for all MMA instructions.
7471 class WMMA_INSTR _Args>
7472 : NVPTXInst<(outs), (ins), "?", []> {
7473 Intrinsic Intr = !cast(_Intr);
7474 // Concatenate all arguments into a single dag.
7475 dag Args = !foldl((ins), _Args, a, b, !con(a,b));
7476 // Pre-build the pattern to match (intrinsic arg0, arg1, ...).
7477 dag IntrinsicPattern = BuildPatternI(Intr), Args>.ret;
74617478 }
74627479
74637480 //
74647481 // wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
74657482 //
74667483
7467 class WMMA_LOAD_INTR_HELPER
7468 bit WithStride>
7469 : PatFrag <(ops),(ops)> {
7470 // Intrinsic that matches this instruction.
7471 Intrinsic Intr = !cast(WMMA_NAME_LDST<"load", Frag, Layout,
7472 WithStride>.record);
7473 let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
7474 let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
7475 let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared,
7476 !eq(Space, ".global"): AS_match.global,
7477 1: AS_match.generic);
7478 }
7479
74807484 class WMMA_LOAD
74817485 DAGOperand SrcOp>
7482 : EmptyNVPTXInst,
7486 : WMMA_INSTR.record,
7487 [!con((ins SrcOp:$src),
7488 !if(WithStride, (ins Int32Regs:$ldm), (ins)))]>,
74837489 Requires {
7484 // Pattern that matches the intrinsic for this instruction variant.
7485 PatFrag IntrMatcher = WMMA_LOAD_INTR_HELPER;
7486 dag Ins = !con((ins SrcOp:$src), !if(WithStride, (ins Int32Regs:$ldm), (ins)));
7487
7488 let Pattern = [BuildPattern.ret];
7490 // Load/store intrinsics are overloaded on pointer's address space.
7491 // To match the right intrinsic, we need to build AS-constrained PatFrag.
7492 // Operands is a dag equivalent in shape to Args, but using (ops node:$name, .....).
7493 dag PFOperands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
7494 // Build PatFrag that only matches particular address space.
7495 PatFrag IntrFrag = PatFrag
7496 !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)),
7497 !cond(!eq(Space, ".shared"): AS_match.shared,
7498 !eq(Space, ".global"): AS_match.global,
7499 1: AS_match.generic)>;
7500 // Build AS-constrained pattern.
7501 let IntrinsicPattern = BuildPatternPF.ret;
7502
74897503 let OutOperandList = Frag.Outs;
7490 let InOperandList = Ins;
7504 let InOperandList = !con(Args, (ins MmaCode:$ptx));
74917505 let AsmString = "wmma.load."
74927506 # Frag.frag
74937507 # ".sync"
7508 # "${ptx:aligned}"
74947509 # "." # Layout
74957510 # "." # Frag.geom
74967511 # Space
75047519 //
75057520 // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
75067521 //
7507 class WMMA_STORE_INTR_HELPER
7508 bit WithStride>
7509 : PatFrag <(ops),(ops)> {
7510 // Intrinsic that matches this instruction.
7511 Intrinsic Intr = !cast(WMMA_NAME_LDST<"store", Frag, Layout,
7512 WithStride>.record);
7513 let Operands = !con((ops node:$dst),
7514 !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names),
7515 !if(WithStride, (ops node:$ldm), (ops)));
7516 let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
7517 let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared,
7518 !eq(Space, ".global"): AS_match.global,
7519 1: AS_match.generic);
7520 }
7521
7522 class WMMA_STORE
7523 DAGOperand DstOp>
7524 : EmptyNVPTXInst,
7522 class WMMA_STORE_D,
7523 bit WithStride, DAGOperand DstOp>
7524 : WMMA_INSTR.record,
7525 [!con((ins DstOp:$dst),
7526 Frag.Ins,
7527 !if(WithStride, (ins Int32Regs:$ldm), (ins)))]>,
75257528 Requires {
7526 PatFrag IntrMatcher = WMMA_STORE_INTR_HELPER;
7527 dag Ins = !con((ins DstOp:$src),
7528 Frag.Ins,
7529 !if(WithStride, (ins Int32Regs:$ldm), (ins)));
7530 let Pattern = [BuildPattern<(set), IntrMatcher, Ins>.ret];
7529
7530 // Load/store intrinsics are overloaded on pointer's address space.
7531 // To match the right intrinsic, we need to build AS-constrained PatFrag.
7532 // Operands is a dag equivalent in shape to Args, but using (ops node:$name, .....).
7533 dag PFOperands = !con((ops node:$dst),
7534 !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names),
7535 !if(WithStride, (ops node:$ldm), (ops)));
7536 // Build PatFrag that only matches particular address space.
7537 PatFrag IntrFrag = PatFrag
7538 !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)),
7539 !cond(!eq(Space, ".shared"): AS_match.shared,
7540 !eq(Space, ".global"): AS_match.global,
7541 1: AS_match.generic)>;
7542 // Build AS-constrained pattern.
7543 let IntrinsicPattern = BuildPatternPF.ret;
7544
7545 let InOperandList = !con(Args, (ins MmaCode:$ptx));
75317546 let OutOperandList = (outs);
7532 let InOperandList = Ins;
7533 let AsmString = "wmma.store.d.sync."
7534 # Layout
7547 let AsmString = "wmma.store.d.sync"
7548 # "${ptx:aligned}"
7549 # "." # Layout
75357550 # "." # Frag.geom
75367551 # Space
75377552 # "." # Frag.ptx_elt_type
7538 # " \t[$src],"
7553 # " \t[$dst],"
75397554 # Frag.regstring
75407555 # !if(WithStride, ", $ldm", "")
75417556 # ";";
75427557 }
75437558
75447559 // Create all load/store variants
7545 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
7546 foreach layout = ["row", "col"] in {
7547 foreach stride = [0, 1] in {
7548 foreach space = [".global", ".shared", ""] in {
7549 foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
7550 foreach frag = [WMMA_REGINFO,
7551 WMMA_REGINFO,
7552 WMMA_REGINFO,
7553 WMMA_REGINFO] in {
7554 def : WMMA_LOAD;
7555 }
7556 foreach frag = [WMMA_REGINFO,
7557 WMMA_REGINFO] in {
7558 def : WMMA_STORE;
7559 }
7560 } // addr
7561 } // space
7562 } // stride
7563 } // layout
7564 } // geom
7560 defset list MMA_LDSTs = {
7561 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
7562 foreach layout = ["row", "col"] in {
7563 foreach stride = [0, 1] in {
7564 foreach space = [".global", ".shared", ""] in {
7565 foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
7566 foreach frag = [WMMA_REGINFO,
7567 WMMA_REGINFO,
7568 WMMA_REGINFO,
7569 WMMA_REGINFO] in {
7570 def : WMMA_LOAD;
7571 }
7572 foreach frag = [WMMA_REGINFO,
7573 WMMA_REGINFO] in {
7574 def : WMMA_STORE_D;
7575 }
7576 } // addr
7577 } // space
7578 } // stride
7579 } // layout
7580 } // geom
7581 } // defset
75657582
75667583 // WMMA.MMA
75677584 class WMMA_MMA
75687585 WMMA_REGINFO FragC, WMMA_REGINFO FragD,
75697586 string ALayout, string BLayout, int Satfinite>
7570 : EmptyNVPTXInst,
7587 : WMMA_INSTR.record,
7588 [FragA.Ins, FragB.Ins, FragC.Ins]>,
75717589 Requires {
7572 //Intrinsic Intr = int_nvvm_suld_1d_v4i32_zero;
7573 Intrinsic Intr = !cast(WMMA_NAME_MMA.record);
7574 dag Outs = FragD.Outs;
7575 dag Ins = !con(FragA.Ins,
7576 FragB.Ins,
7577 FragC.Ins);
7578
7579 // Construct the pattern to match corresponding intrinsic call.
7580 // mma does not load/store anything, so we don't need complex operand matching here.
7581 dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp));
7582 dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp));
7583 let Pattern = [!con(PatOuts, (set PatArgs))];
7584 let OutOperandList = Outs;
7585 let InOperandList = Ins;
7586 let AsmString = "wmma.mma.sync."
7587 # ALayout
7590 let OutOperandList = FragD.Outs;
7591 let InOperandList = !con(Args, (ins MmaCode:$ptx));
7592 let AsmString = "wmma.mma.sync"
7593 # "${ptx:aligned}"
7594 # "." # ALayout
75887595 # "." # BLayout
75897596 # "." # FragA.geom
75907597 # "." # FragD.ptx_elt_type
75967603 # FragC.regstring # ";";
75977604 }
75987605
7599 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
7600 foreach layout_a = ["row", "col"] in {
7601 foreach layout_b = ["row", "col"] in {
7602 foreach frag_c = [WMMA_REGINFO,
7603 WMMA_REGINFO] in {
7604 foreach frag_d = [WMMA_REGINFO,
7605 WMMA_REGINFO] in {
7606 foreach satf = [0, 1] in {
7607 def : WMMA_MMA,
7608 WMMA_REGINFO,
7609 frag_c, frag_d, layout_a, layout_b, satf>;
7610 } // satf
7611 } // frag_d
7612 } // frag_c
7613 } // layout_b
7614 } // layout_a
7615 } // geom
7606 defset list MMAs = {
7607 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
7608 foreach layout_a = ["row", "col"] in {
7609 foreach layout_b = ["row", "col"] in {
7610 foreach frag_c = [WMMA_REGINFO,
7611 WMMA_REGINFO] in {
7612 foreach frag_d = [WMMA_REGINFO,
7613 WMMA_REGINFO] in {
7614 foreach satf = [0, 1] in {
7615 def : WMMA_MMA,
7616 WMMA_REGINFO,
7617 frag_c, frag_d, layout_a, layout_b, satf>;
7618 } // satf
7619 } // frag_d
7620 } // frag_c
7621 } // layout_b
7622 } // layout_a
7623 } // geom
7624 } // defset
7625
7626 // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
7627 // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
7628 // the instruction record.
7629 class WMMA_PAT
7630 : Pat
7631 !con(!foreach(tmp, wi.Args, !subst(ins, wi, tmp)),
7632 (wi ptx.version))>;
7633
7634 // Build intrinsic->instruction patterns for all MMA instructions.
7635 foreach mma = !listconcat(MMAs, MMA_LDSTs) in
7636 def : WMMA_PAT;
22
33 # RUN: python %s > %t.ll
44 # RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll
5 # RUN: python %s --ptx=63 > %t-ptx63.ll
6 # RUN: llc < %t-ptx63.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx63 | FileCheck %t-ptx63.ll
57
68 from __future__ import print_function
79
10 import argparse
811 from itertools import product
912 from string import Template
1013
6366 }
6467 """
6568 intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
66 instruction_template = "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}"
69 instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
6770
6871 for geom, abc, layout, space, stride, itype in product(
6972 known_geoms,
7578
7679 params = {
7780 "abc" : abc,
81 "aligned" : ".aligned" if ptx_version >= 63 else "",
7882 "layout" : layout,
7983 "space" : space,
8084 "stride" : stride,
134138 }
135139 """
136140 intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
137 instruction_template = "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}"
141 instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
138142
139143 for geom, abc, layout, space, stride, itype in product(
140144 known_geoms,
146150
147151 params = {
148152 "abc" : abc,
153 "aligned" : ".aligned" if ptx_version >= 63 else "",
149154 "layout" : layout,
150155 "space" : space,
151156 "stride" : stride,
190195 }
191196 """
192197 intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
193 instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
198 instruction_template = "wmma.mma.sync${aligned}.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
194199
195200 for geom, alayout, blayout, ctype, dtype, satf in product(
196201 known_geoms,
201206 [".satfinite", ""]):
202207
203208 params = {
209 "aligned" : ".aligned" if ptx_version >= 63 else "",
204210 "alayout" : alayout,
205211 "blayout" : blayout,
206212 "ctype" : ctype,
229235 gen_wmma_store_tests()
230236 gen_wmma_mma_tests()
231237
238 parser = argparse.ArgumentParser()
239 parser.add_argument('--ptx', type=int, default=60)
240 args = parser.parse_args()
241 ptx_version = args.ptx
242
232243 main()