llvm.org GIT mirror llvm / 5100f1a
[NVPTX] Refactor generation of MMA intrinsics and instructions. NFC. Generalized constructions of 'fragments' of MMA operations to provide common primitives for construction of the ops. This will make it easier to add new variants of the instructions that operate on integer types. Use nested foreach loops which makes it possible to better control naming of the intrinsics. This patch does not affect LLVM's output, so there are no test changes. Differential Revision: https://reviews.llvm.org/D59389 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@359245 91177308-0d34-0410-b5e6-96231b3b80d8 Artem Belevich 1 year, 7 months ago
2 changed file(s) with 301 addition(s) and 481 deletion(s). Raw diff Collapse all Expand all
3535 //
3636 // MISC
3737 //
38
39 // Helper class for construction of n-element list [t,t,...,t]
40 class RepLLVMType {
41 list ret = !if(N, !listconcat(RepLLVMType.ret, [T]), []);
42 }
43
44 // Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
45 // Geom: mnk. E.g. m8n32k16
46 // Frag: [abcd]
47 // PtxEltType: PTX type for the element.
48 class WMMA_REGS {
49 string geom = Geom;
50 string frag = Frag;
51 string ptx_elt_type = PtxEltType;
52 string ft = frag#":"#ptx_elt_type;
53 list regs = !cond(
54 // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
55 // All currently supported geometries use the same fragment format,
56 // so we only need to consider {fragment, type}.
57 !eq(ft,"a:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret,
58 !eq(ft,"b:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret,
59 !eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
60 !eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
61 !eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret,
62 !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret);
63 }
64
65 class WMMA_NAME_LDST {
66 string intr = "llvm.nvvm.wmma."
67 # Frag.geom
68 # "." # Op
69 # "." # Frag.frag
70 # "." # Layout
71 # !if(WithStride, ".stride", "")
72 # "." # Frag.ptx_elt_type
73 ;
74 // TODO(tra): record name should ideally use the same field order as the intrinsic.
75 // E.g. string record = !subst("llvm", "int",
76 // !subst(".", "_", llvm));
77 string record = "int_nvvm_wmma_"
78 # Frag.geom
79 # "_" # Op
80 # "_" # Frag.frag
81 # "_" # Frag.ptx_elt_type
82 # "_" # Layout
83 # !if(WithStride, "_stride", "");
84 }
85
86 class WMMA_NAME_MMA
87 WMMA_REGS C, WMMA_REGS D,
88 int Satfinite> {
89 string llvm = "llvm.nvvm.wmma."
90 # C.geom
91 # ".mma"
92 # "." # ALayout
93 # "." # BLayout
94 # "." # D.ptx_elt_type // Intrinsic encodes 'd' first.
95 # "." # C.ptx_elt_type
96 # !if(Satfinite, ".satfinite", "");
97
98 string record = !subst(".", "_",
99 !subst("llvm.", "int_", llvm));
100 }
38101
39102 let TargetPrefix = "nvvm" in {
40103 def int_nvvm_prmt : GCCBuiltin<"__nvvm_prmt">,
38883951 //
38893952 // WMMA instructions
38903953 //
3891
38923954 // WMMA.LOAD
3893 class NVVM_WMMA_LD_GALSTS
3894 string Type, LLVMType regty, int WithStride>
3895 : Intrinsic
3896 [regty, regty, regty, regty],
3897 [regty, regty, regty, regty,
3898 regty, regty, regty, regty]),
3955 class NVVM_WMMA_LD
3956 : Intrinsic
38993957 !if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
39003958 [IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
3901 "llvm.nvvm.wmma."
3902 # Geometry
3903 # ".load"
3904 # "." # Abc
3905 # "." # Layout
3906 # !if(WithStride, ".stride", "")
3907 # "." # Type>;
3908
3909 multiclass NVVM_WMMA_LD_GALT
3910 string Type, LLVMType regty> {
3911 def _stride: NVVM_WMMA_LD_GALSTS;
3912 def NAME : NVVM_WMMA_LD_GALSTS;
3913 }
3914
3915 multiclass NVVM_WMMA_LD_GAT
3916 string Type, LLVMType regty> {
3917 defm _row: NVVM_WMMA_LD_GALT;
3918 defm _col: NVVM_WMMA_LD_GALT;
3919 }
3920
3921 multiclass NVVM_WMMA_LD_G {
3922 defm _a_f16: NVVM_WMMA_LD_GAT;
3923 defm _b_f16: NVVM_WMMA_LD_GAT;
3924 defm _c_f16: NVVM_WMMA_LD_GAT;
3925 defm _c_f32: NVVM_WMMA_LD_GAT;
3926 }
3927
3928 multiclass NVVM_WMMA_LD {
3929 defm _m32n8k16_load: NVVM_WMMA_LD_G<"m32n8k16">;
3930 defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">;
3931 defm _m8n32k16_load: NVVM_WMMA_LD_G<"m8n32k16">;
3932 }
3933
3934 defm int_nvvm_wmma: NVVM_WMMA_LD;
3959 WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.intr>;
39353960
39363961 // WMMA.STORE.D
3937 class NVVM_WMMA_STD_GLSTS
3938 string Type, LLVMType regty, int WithStride,
3939 // This is only used to create a typed empty array we
3940 // need to pass to !if below.
3941 listEmpty=[]>
3962 class NVVM_WMMA_ST>
39423963 : Intrinsic<[],
39433964 !listconcat(
39443965 [llvm_anyptr_ty],
3945 !if(!eq(Type,"f16"),
3946 [regty, regty, regty, regty],
3947 [regty, regty, regty, regty,
3948 regty, regty, regty, regty]),
3949 !if(WithStride, [llvm_i32_ty], Empty)),
3966 Frag.regs,
3967 !if(WithStride, [llvm_i32_ty], [])),
39503968 [IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
3951 "llvm.nvvm.wmma."
3952 # Geometry
3953 # ".store.d"
3954 # "." # Layout
3955 # !if(WithStride, ".stride", "")
3956 # "." # Type>;
3957
3958 multiclass NVVM_WMMA_STD_GLT
3959 string Type, LLVMType regty> {
3960 def _stride: NVVM_WMMA_STD_GLSTS;
3961 def NAME: NVVM_WMMA_STD_GLSTS>;
3969 WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>;
3970
3971 // Create all load/store variants
3972 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
3973 foreach layout = ["row", "col"] in {
3974 foreach stride = [0, 1] in {
3975 foreach frag = [WMMA_REGS,
3976 WMMA_REGS,
3977 WMMA_REGS,
3978 WMMA_REGS] in {
3979 def WMMA_NAME_LDST<"load", frag, layout, stride>.record
3980 : NVVM_WMMA_LD;
3981 }
3982 foreach frag = [WMMA_REGS,
3983 WMMA_REGS] in {
3984 def WMMA_NAME_LDST<"store", frag, layout, stride>.record
3985 : NVVM_WMMA_ST;
3986 }
3987 }
3988 }
39623989 }
39633990
3964 multiclass NVVM_WMMA_STD_GT {
3965 defm _row: NVVM_WMMA_STD_GLT;
3966 defm _col: NVVM_WMMA_STD_GLT;
3991 // WMMA.MMA
3992 class NVVM_WMMA_MMA
3993 WMMA_REGS C, WMMA_REGS D, int Satfinite>
3994 : Intrinsic
3995 !listconcat(
3996 WMMA_REGS.regs,
3997 WMMA_REGS.regs,
3998 C.regs),
3999 [IntrNoMem],
4000 WMMA_NAME_MMA.llvm>;
4001
4002 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
4003 foreach layout_a = ["row", "col"] in {
4004 foreach layout_b = ["row", "col"] in {
4005 foreach frag_c = [WMMA_REGS,
4006 WMMA_REGS] in {
4007 foreach frag_d = [WMMA_REGS,
4008 WMMA_REGS] in {
4009 foreach satf = [0, 1] in {
4010 def WMMA_NAME_MMA.record
4011 : NVVM_WMMA_MMA;
4012 }
4013 }
4014 }
4015 }
4016 }
39674017 }
3968 multiclass NVVM_WMMA_STD_G {
3969 defm _d_f16: NVVM_WMMA_STD_GT;
3970 defm _d_f32: NVVM_WMMA_STD_GT;
3971 }
3972
3973 multiclass NVVM_WMMA_STD {
3974 defm _m32n8k16_store: NVVM_WMMA_STD_G<"m32n8k16">;
3975 defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">;
3976 defm _m8n32k16_store: NVVM_WMMA_STD_G<"m8n32k16">;
3977 }
3978
3979 defm int_nvvm_wmma: NVVM_WMMA_STD;
3980
3981 // WMMA.MMA
3982 class NVVM_WMMA_MMA_GABDCS
3983 string ALayout, string BLayout,
3984 string DType, LLVMType d_regty,
3985 string CType, LLVMType c_regty,
3986 string Satfinite = "">
3987 : Intrinsic
3988 [d_regty, d_regty, d_regty, d_regty],
3989 [d_regty, d_regty, d_regty, d_regty,
3990 d_regty, d_regty, d_regty, d_regty]),
3991 !listconcat(
3992 [// A
3993 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
3994 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
3995 // B
3996 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
3997 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
3998 !if(!eq(CType,"f16"),
3999 [c_regty, c_regty, c_regty, c_regty],
4000 [c_regty, c_regty, c_regty, c_regty,
4001 c_regty, c_regty, c_regty, c_regty])),
4002 [IntrNoMem],
4003 "llvm.nvvm.wmma."
4004 # Geometry
4005 # ".mma"
4006 # "." # ALayout
4007 # "." # BLayout
4008 # "." # DType
4009 # "." # CType
4010 # Satfinite> {
4011 }
4012
4013 multiclass NVVM_WMMA_MMA_GABDC
4014 string DType, LLVMType d_regty,
4015 string CType, LLVMType c_regty> {
4016 def NAME : NVVM_WMMA_MMA_GABDCS
4017 DType, d_regty, CType, c_regty>;
4018 def _satfinite: NVVM_WMMA_MMA_GABDCS
4019 DType, d_regty, CType, c_regty,".satfinite">;
4020 }
4021
4022 multiclass NVVM_WMMA_MMA_GABD
4023 string DType, LLVMType d_regty> {
4024 defm _f16: NVVM_WMMA_MMA_GABDC
4025 "f16", llvm_v2f16_ty>;
4026 defm _f32: NVVM_WMMA_MMA_GABDC
4027 "f32", llvm_float_ty>;
4028 }
4029
4030 multiclass NVVM_WMMA_MMA_GAB {
4031 defm _f16: NVVM_WMMA_MMA_GABD;
4032 defm _f32: NVVM_WMMA_MMA_GABD;
4033 }
4034
4035 multiclass NVVM_WMMA_MMA_GA {
4036 defm _col: NVVM_WMMA_MMA_GAB;
4037 defm _row: NVVM_WMMA_MMA_GAB;
4038 }
4039
4040 multiclass NVVM_WMMA_MMA_G {
4041 defm _col: NVVM_WMMA_MMA_GA;
4042 defm _row: NVVM_WMMA_MMA_GA;
4043 }
4044
4045 multiclass NVVM_WMMA_MMA {
4046 defm _m32n8k16_mma : NVVM_WMMA_MMA_G<"m32n8k16">;
4047 defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">;
4048 defm _m8n32k16_mma : NVVM_WMMA_MMA_G<"m8n32k16">;
4049 }
4050
4051 defm int_nvvm_wmma : NVVM_WMMA_MMA;
40524018
40534019 } // let TargetPrefix = "nvvm"
2525 return (d==1.0);
2626 }]>;
2727
28
28 def AS_match {
29 code generic = [{
30 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
31 }];
32 code shared = [{
33 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
34 }];
35 code global = [{
36 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
37 }];
38 }
2939
3040 //-----------------------------------
3141 // Synchronization and shuffle functions
10051015 //-----------------------------------
10061016
10071017 class ATOMIC_GLOBAL_CHK
1008 : PatFrag
1009 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
1010 }]>;
1018 : PatFrag>;
10111019 class ATOMIC_SHARED_CHK
1012 : PatFrag
1013 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
1014 }]>;
1020 : PatFrag>;
10151021 class ATOMIC_GENERIC_CHK
1016 : PatFrag
1017 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
1018 }]>;
1022 : PatFrag>;
10191023
10201024 multiclass F_ATOMIC_2_imp
10211025 string SpaceStr, string TypeStr, string OpcStr, PatFrag IntOp,
73797383 NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;",
73807384 [(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>;
73817385
7382 //
7383 // wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
7384 //
7385
73867386 class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
7387
7388 class WMMA_LOAD_GALSTOS
7389 string Space, string Type, NVPTXRegClass regclass,
7390 DAGOperand SrcOp, bit WithStride>
7391 : EmptyNVPTXInst,
7392 Requires<[!if(!eq(Geometry, "m16n16k16"),
7393 hasPTX60,
7394 hasPTX61),
7395 hasSM70]> {
7396 // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
7397 // for this function.
7398 PatFrag IntrMatcher = !cast("INT_WMMA_"
7399 # Geometry # "_load_"
7400 # !subst("c", "c_" # Type, Abc)
7401 # "_" # Layout
7402 # !subst(".", "_", Space)
7403 # !if(WithStride,"_stride", "")
7404 # "_Intr");
7405 dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
7406 dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
7407 dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47));
7408
7409 dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
7410 dag Ins = !con((ins SrcOp:$src), StrideArg);
7411
7387 // Generates list of n sequential register names.
7388 class RegSeq {
7389 list ret = !if(n, !listconcat(RegSeq.ret,
7390 [prefix # !add(n, -1)]),
7391 []);
7392 }
7393
7394 // Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
7395 // In addition to target-independent fields provided by WMMA_REGS, it adds
7396 // the fields commonly used to implement specific PTX instruction -- register
7397 // types and names, constraints, parts of assembly, etc.
7398 class WMMA_REGINFO
7399 : WMMA_REGS {
7400 // NVPTX register types used to carry fragment data.
7401 NVPTXRegClass regclass = !cond(
7402 !eq(PtxEltType, "f16") : Float16x2Regs,
7403 !eq(PtxEltType, "f32") : Float32Regs);
7404
7405 // Instruction input/output arguments for the fragment.
7406 list ptx_regs = !foreach(tmp, regs, regclass);
7407
7408 // List of register names for the fragment -- ["ra0", "ra1",...]
7409 list reg_names = RegSeq.ret;
7410 // Generates "{{$r0, $r1,.... $rN-1}}" for use in asm string construction.
7411 string regstring = "{{$" # !head(reg_names)
7412 # !foldl("", !tail(reg_names), a, b,
7413 !strconcat(a, ", $", b))
7414 # "}}";
7415
7416 // Predicates for particular fragment variant. Technically those are
7417 // per-instruction predicates, but currently all fragments that can be used in
7418 // a given instruction are subject to the same constraints, so an instruction
7419 // can use predicates from any of its fragments. If/when this is no
7420 // longer the case, we can concat all per-fragment predicates to enforce that
7421 // all fragments of the instruction are viable.
7422 list Predicates = !cond(
7423 // fp16 -> fp16/fp32 @ m16n16k16
7424 !and(!eq(Geom, "m16n16k16"),
7425 !or(!eq(PtxEltType, "f16"),
7426 !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60],
7427
7428 // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
7429 !and(!or(!eq(Geom, "m8n32k16"),
7430 !eq(Geom, "m32n8k16")),
7431 !or(!eq(PtxEltType, "f16"),
7432 !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]);
7433
7434 // template DAGs for instruction inputs/output.
7435 dag Outs = !dag(outs, ptx_regs, reg_names);
7436 dag Ins = !dag(ins, ptx_regs, reg_names);
7437 }
7438
7439 class BuildPattern {
74127440 // Build a dag pattern that matches the intrinsic call.
74137441 // We want a dag that looks like this:
74147442 // (set , (intrinsic )) where input and
74297457 !subst(ins, IntrMatcher, tmp)))));
74307458 // Finally, consatenate both parts together. !con() requires both dags to have
74317459 // the same operator, so we wrap PatArgs in a (set ...) dag.
7432 let Pattern = [!con(PatOuts, (set PatArgs))];
7433 let OutOperandList = Outs;
7460 dag ret = !con(PatOuts, (set PatArgs));
7461 }
7462
7463 //
7464 // wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
7465 //
7466
7467 class WMMA_LOAD_INTR_HELPER
7468 bit WithStride>
7469 : PatFrag <(ops),(ops)> {
7470 // Intrinsic that matches this instruction.
7471 Intrinsic Intr = !cast(WMMA_NAME_LDST<"load", Frag, Layout,
7472 WithStride>.record);
7473 let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
7474 let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
7475 let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared,
7476 !eq(Space, ".global"): AS_match.global,
7477 1: AS_match.generic);
7478 }
7479
7480 class WMMA_LOAD
7481 DAGOperand SrcOp>
7482 : EmptyNVPTXInst,
7483 Requires {
7484 // Pattern that matches the intrinsic for this instruction variant.
7485 PatFrag IntrMatcher = WMMA_LOAD_INTR_HELPER;
7486 dag Ins = !con((ins SrcOp:$src), !if(WithStride, (ins Int32Regs:$ldm), (ins)));
7487
7488 let Pattern = [BuildPattern.ret];
7489 let OutOperandList = Frag.Outs;
74347490 let InOperandList = Ins;
74357491 let AsmString = "wmma.load."
7436 # Abc
7492 # Frag.frag
74377493 # ".sync"
74387494 # "." # Layout
7439 # "." # Geometry
7495 # "." # Frag.geom
74407496 # Space
7441 # "." # Type # " \t"
7442 # !if(!eq(Abc#Type, "cf16"),
7443 "{{$r0, $r1, $r2, $r3}}",
7444 "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
7497 # "." # Frag.ptx_elt_type # " \t"
7498 # Frag.regstring
74457499 # ", [$src]"
74467500 # !if(WithStride, ", $ldm", "")
74477501 # ";";
74487502 }
74497503
7450 class WMMA_LOAD_INTR_HELPER
7451 string Space, string Type, bit WithStride>
7452 : PatFrag <(ops),(ops)> {
7453 // Intrinsic that matches this instruction.
7454 Intrinsic Intr = !cast("int_nvvm_wmma"
7455 # "_" # Geometry # "_load_"
7456 # Abc # "_" # Type # "_" # Layout
7457 # !if(WithStride,"_stride", ""));
7458 code match_generic = [{
7459 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
7460 }];
7461 code match_shared = [{
7462 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
7463 }];
7464 code match_global = [{
7465 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
7466 }];
7467
7468 let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
7469 let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
7470 let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
7471 !if(!eq(Space, ".global"), match_global, match_generic));
7472 }
7473
7474 multiclass WMMA_LOAD_GALSTS
7475 string Space, string Type, NVPTXRegClass regclass,
7476 bit WithStride> {
7477 def _avar: WMMA_LOAD_GALSTOS
7478 imem, WithStride>;
7479 def _areg: WMMA_LOAD_GALSTOS
7480 Int32Regs, WithStride>;
7481 def _areg64: WMMA_LOAD_GALSTOS
7482 Int64Regs, WithStride>;
7483 def _ari: WMMA_LOAD_GALSTOS
7484 MEMri, WithStride>;
7485 def _ari64: WMMA_LOAD_GALSTOS
7486 MEMri64, WithStride>;
7487 }
7488
7489 multiclass WMMA_LOAD_GALSTSh
7490 string Space, string Type, NVPTXRegClass regclass,
7491 bit WithStride> {
7492 // Define a PatFrag that matches appropriate intrinsic that loads from the
7493 // given address space.
7494 def _Intr: WMMA_LOAD_INTR_HELPER
7495 WithStride>;
7496 defm NAME: WMMA_LOAD_GALSTS
7497 WithStride>;
7498 }
7499
7500 multiclass WMMA_LOAD_GALST
7501 string Space, string Type, NVPTXRegClass regclass> {
7502 defm _stride: WMMA_LOAD_GALSTSh;
7503 defm NAME: WMMA_LOAD_GALSTSh;
7504 }
7505
7506 multiclass WMMA_LOAD_GALT
7507 string Type, NVPTXRegClass regclass> {
7508 defm _global: WMMA_LOAD_GALST
7509 Type, regclass>;
7510 defm _shared: WMMA_LOAD_GALST
7511 Type, regclass>;
7512 defm NAME: WMMA_LOAD_GALST
7513 Type, regclass>;
7514 }
7515
7516 multiclass WMMA_LOAD_GAT
7517 string Type, NVPTXRegClass regclass> {
7518 defm _row: WMMA_LOAD_GALT;
7519 defm _col: WMMA_LOAD_GALT;
7520 }
7521
7522 multiclass WMMA_LOAD_G {
7523 defm _load_a: WMMA_LOAD_GAT;
7524 defm _load_b: WMMA_LOAD_GAT;
7525 defm _load_c_f16: WMMA_LOAD_GAT;
7526 defm _load_c_f32: WMMA_LOAD_GAT;
7527 }
7528
7529 defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">;
7530 defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
7531 defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">;
7532
75337504 //
75347505 // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
75357506 //
7536 class WMMA_STORE_D_GLSTSO
7537 string Type, NVPTXRegClass regclass,
7538 bit WithStride, DAGOperand DstOp>
7507 class WMMA_STORE_INTR_HELPER
7508 bit WithStride>
7509 : PatFrag <(ops),(ops)> {
7510 // Intrinsic that matches this instruction.
7511 Intrinsic Intr = !cast(WMMA_NAME_LDST<"store", Frag, Layout,
7512 WithStride>.record);
7513 let Operands = !con((ops node:$dst),
7514 !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names),
7515 !if(WithStride, (ops node:$ldm), (ops)));
7516 let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
7517 let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared,
7518 !eq(Space, ".global"): AS_match.global,
7519 1: AS_match.generic);
7520 }
7521
7522 class WMMA_STORE
7523 DAGOperand DstOp>
75397524 : EmptyNVPTXInst,
7540 Requires<[!if(!eq(Geometry, "m16n16k16"),
7541 hasPTX60,
7542 hasPTX61),
7543 hasSM70]> {
7544 PatFrag IntrMatcher = !cast("INT_WMMA"
7545 # "_" # Geometry # "_store_d"
7546 # "_" # Type
7547 # "_" # Layout
7548 # !subst(".", "_", Space)
7549 # !if(WithStride,"_stride", "")
7550 # "_Intr");
7551 dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1,
7552 regclass:$r2, regclass:$r3);
7553 dag InsR47 = (ins regclass:$r4, regclass:$r5,
7554 regclass:$r6, regclass:$r7);
7555 dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47));
7556 dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
7557 dag Ins = !con(InsR, StrideArg);
7558
7559 // Construct the pattern to match corresponding intrinsic call. See the
7560 // details in the comments in WMMA_LOAD_ALSTOS.
7561 dag PatArgs = !foreach(tmp, Ins,
7562 !subst(imem, ADDRvar,
7563 !subst(MEMri64, ADDRri64,
7564 !subst(MEMri, ADDRri,
7565 !subst(ins, IntrMatcher, tmp)))));
7566 let Pattern = [PatArgs];
7525 Requires {
7526 PatFrag IntrMatcher = WMMA_STORE_INTR_HELPER;
7527 dag Ins = !con((ins DstOp:$src),
7528 Frag.Ins,
7529 !if(WithStride, (ins Int32Regs:$ldm), (ins)));
7530 let Pattern = [BuildPattern<(set), IntrMatcher, Ins>.ret];
75677531 let OutOperandList = (outs);
75687532 let InOperandList = Ins;
75697533 let AsmString = "wmma.store.d.sync."
75707534 # Layout
7571 # "." # Geometry
7535 # "." # Frag.geom
75727536 # Space
7573 # "." # Type
7537 # "." # Frag.ptx_elt_type
75747538 # " \t[$src],"
7575 # !if(!eq(Type,"f16"),
7576 "{{$r0, $r1, $r2, $r3}}",
7577 "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
7539 # Frag.regstring
75787540 # !if(WithStride, ", $ldm", "")
75797541 # ";";
7580
75817542 }
75827543
7583 class WMMA_STORE_INTR_HELPER
7584 string Type, bit WithStride>
7585 : PatFrag <(ops),(ops)> {
7586 // Intrinsic that matches this instruction.
7587 Intrinsic Intr = !cast("int_nvvm_wmma_"
7588 # Geometry
7589 # "_store_d"
7590 # "_" # Type
7591 # "_" # Layout
7592 # !if(WithStride, "_stride", ""));
7593 code match_generic = [{
7594 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
7595 }];
7596 code match_shared = [{
7597 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
7598 }];
7599 code match_global = [{
7600 return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
7601 }];
7602
7603 dag Args = !if(!eq(Type,"f16"),
7604 (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3),
7605 (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3,
7606 node:$r4, node:$r5, node:$r6, node:$r7));
7607 dag StrideArg = !if(WithStride, (ops node:$ldm), (ops));
7608 let Operands = !con(Args, StrideArg);
7609 let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
7610 let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
7611 !if(!eq(Space, ".global"), match_global, match_generic));
7612 }
7613
7614 multiclass WMMA_STORE_D_GLSTS
7615 string Type, NVPTXRegClass regclass,
7616 bit WithStride> {
7617 def _avar: WMMA_STORE_D_GLSTSO
7618 WithStride, imem>;
7619 def _areg: WMMA_STORE_D_GLSTSO
7620 WithStride, Int32Regs>;
7621 def _areg64: WMMA_STORE_D_GLSTSO
7622 WithStride, Int64Regs>;
7623 def _ari: WMMA_STORE_D_GLSTSO
7624 WithStride, MEMri>;
7625 def _ari64: WMMA_STORE_D_GLSTSO
7626 WithStride, MEMri64>;
7627 }
7628
7629 multiclass WMMA_STORE_D_GLSTSh
7630 string Type, NVPTXRegClass regclass,
7631 bit WithStride> {
7632 // Define a PatFrag that matches appropriate intrinsic that loads from the
7633 // given address space.
7634 def _Intr: WMMA_STORE_INTR_HELPER
7635 WithStride>;
7636 defm NAME: WMMA_STORE_D_GLSTS
7637 WithStride>;
7638 }
7639
7640 multiclass WMMA_STORE_D_GLST
7641 string Type, NVPTXRegClass regclass > {
7642 defm _stride: WMMA_STORE_D_GLSTSh;
7643 defm NAME: WMMA_STORE_D_GLSTSh;
7644 }
7645
7646 multiclass WMMA_STORE_D_GLT
7647 string Type, NVPTXRegClass regclass> {
7648 defm _global: WMMA_STORE_D_GLST;
7649 defm _shared: WMMA_STORE_D_GLST;
7650 defm NAME: WMMA_STORE_D_GLST;
7651 }
7652
7653 multiclass WMMA_STORE_D_GT
7654 NVPTXRegClass regclass> {
7655 defm _row: WMMA_STORE_D_GLT;
7656 defm _col: WMMA_STORE_D_GLT;
7657 }
7658
7659 multiclass WMMA_STORE_D_G {
7660 defm _store_d_f16: WMMA_STORE_D_GT;
7661 defm _store_d_f32: WMMA_STORE_D_GT;
7662 }
7663
7664 defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">;
7665 defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
7666 defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">;
7544 // Create all load/store variants
7545 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
7546 foreach layout = ["row", "col"] in {
7547 foreach stride = [0, 1] in {
7548 foreach space = [".global", ".shared", ""] in {
7549 foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
7550 foreach frag = [WMMA_REGINFO,
7551 WMMA_REGINFO,
7552 WMMA_REGINFO,
7553 WMMA_REGINFO] in {
7554 def : WMMA_LOAD;
7555 }
7556 foreach frag = [WMMA_REGINFO,
7557 WMMA_REGINFO] in {
7558 def : WMMA_STORE;
7559 }
7560 } // addr
7561 } // space
7562 } // stride
7563 } // layout
7564 } // geom
76677565
76687566 // WMMA.MMA
7669 class WMMA_MMA_GABDCS
7670 string DType, NVPTXRegClass d_reg,
7671 string CType, NVPTXRegClass c_reg,
7672 NVPTXRegClass ab_reg,
7673 string Satfinite = "">
7567 class WMMA_MMA
7568 WMMA_REGINFO FragC, WMMA_REGINFO FragD,
7569 string ALayout, string BLayout, int Satfinite>
76747570 : EmptyNVPTXInst,
7675 Requires<[!if(!eq(Geometry, "m16n16k16"),
7676 hasPTX60,
7677 hasPTX61),
7678 hasSM70]> {
7679 Intrinsic Intr = !cast("int_nvvm_wmma_"
7680 # Geometry
7681 # "_mma"
7682 # "_" # ALayout
7683 # "_" # BLayout
7684 # "_" # DType
7685 # "_" # CType
7686 # !subst(".", "_", Satfinite));
7687 dag Outs = !if(!eq(DType,"f16"),
7688 (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
7689 (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
7690 d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7));
7691 dag InsExtraCArgs = !if(!eq(CType,"f16"),
7692 (ins),
7693 (ins c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7));
7694 dag Ins = !con((ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
7695 ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
7696 ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
7697 ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
7698 c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3),
7699 InsExtraCArgs);
7700
7701 // Construct the pattern to match corresponding intrinsic call. See the
7702 // details in the comments in WMMA_LOAD_ALSTOS.
7571 Requires {
7572 //Intrinsic Intr = int_nvvm_suld_1d_v4i32_zero;
7573 Intrinsic Intr = !cast(WMMA_NAME_MMA.record);
7574 dag Outs = FragD.Outs;
7575 dag Ins = !con(FragA.Ins,
7576 FragB.Ins,
7577 FragC.Ins);
7578
7579 // Construct the pattern to match corresponding intrinsic call.
7580 // mma does not load/store anything, so we don't need complex operand matching here.
77037581 dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp));
77047582 dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp));
77057583 let Pattern = [!con(PatOuts, (set PatArgs))];
77087586 let AsmString = "wmma.mma.sync."
77097587 # ALayout
77107588 # "." # BLayout
7711 # "." # Geometry
7712 # "." # DType
7713 # "." # CType
7714 # Satfinite # "\n\t\t"
7715 # !if(!eq(DType,"f16"),
7716 "{{$d0, $d1, $d2, $d3}}, \n\t\t",
7717 "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
7718 # "{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
7719 # "{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
7720 # !if(!eq(CType,"f16"),
7721 "{{$c0, $c1, $c2, $c3}};",
7722 "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};");
7589 # "." # FragA.geom
7590 # "." # FragD.ptx_elt_type
7591 # "." # FragC.ptx_elt_type
7592 # !if(Satfinite, ".satfinite", "") # "\n\t\t"
7593 # FragD.regstring # ",\n\t\t"
7594 # FragA.regstring # ",\n\t\t"
7595 # FragB.regstring # ",\n\t\t"
7596 # FragC.regstring # ";";
77237597 }
77247598
7725 multiclass WMMA_MMA_GABDC
7726 string DType, NVPTXRegClass d_reg,
7727 string CType, NVPTXRegClass c_reg> {
7728 def _satfinite: WMMA_MMA_GABDCS
7729 DType, d_reg, CType, c_reg,
7730 Float16x2Regs, ".satfinite">;
7731 def NAME: WMMA_MMA_GABDCS
7732 DType, d_reg, CType, c_reg,
7733 Float16x2Regs>;
7734 }
7735
7736 multiclass WMMA_MMA_GABD
7737 string DType, NVPTXRegClass d_reg> {
7738 defm _f16: WMMA_MMA_GABDC
7739 "f16", Float16x2Regs>;
7740 defm _f32: WMMA_MMA_GABDC
7741 "f32", Float32Regs>;
7742 }
7743
7744 multiclass WMMA_MMA_GAB {
7745 defm _f16: WMMA_MMA_GABD;
7746 defm _f32: WMMA_MMA_GABD;
7747 }
7748
7749 multiclass WMMA_MMA_GA {
7750 defm _col: WMMA_MMA_GAB;
7751 defm _row: WMMA_MMA_GAB;
7752 }
7753
7754 multiclass WMMA_MMA_G {
7755 defm _col: WMMA_MMA_GA;
7756 defm _row: WMMA_MMA_GA;
7757 }
7758
7759 defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">;
7760 defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;
7761 defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">;
7599 foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
7600 foreach layout_a = ["row", "col"] in {
7601 foreach layout_b = ["row", "col"] in {
7602 foreach frag_c = [WMMA_REGINFO,
7603 WMMA_REGINFO] in {
7604 foreach frag_d = [WMMA_REGINFO,
7605 WMMA_REGINFO] in {
7606 foreach satf = [0, 1] in {
7607 def : WMMA_MMA,
7608 WMMA_REGINFO,
7609 frag_c, frag_d, layout_a, layout_b, satf>;
7610 } // satf
7611 } // frag_d
7612 } // frag_c
7613 } // layout_b
7614 } // layout_a
7615 } // geom