llvm.org GIT mirror llvm / release_70 test / CodeGen / NVPTX / wmma.py
release_70

Tree @release_70 (Download .tar.gz)

wmma.py @release_70

a187b48
 
 
 
cbae883
a187b48
 
 
 
 
 
 
 
 
 
 
 
9a1bbfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a187b48
 
 
 
 
cbae883
 
a187b48
 
bc1c5ed
a187b48
bc1c5ed
 
cbae883
a187b48
 
bc1c5ed
a187b48
 
 
bc1c5ed
 
cbae883
a187b48
 
9a1bbfc
bc1c5ed
a187b48
 
 
bc1c5ed
cbae883
a187b48
cbae883
 
a187b48
 
 
 
 
 
 
 
 
 
 
9a1bbfc
 
bc1c5ed
cbae883
a187b48
 
 
 
 
 
bc1c5ed
 
 
a187b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc1c5ed
a187b48
bc1c5ed
 
cbae883
a187b48
 
bc1c5ed
a187b48
 
 
bc1c5ed
 
cbae883
a187b48
 
9a1bbfc
bc1c5ed
a187b48
 
 
bc1c5ed
cbae883
a187b48
cbae883
 
a187b48
 
 
 
 
 
 
 
 
 
 
9a1bbfc
 
bc1c5ed
cbae883
a187b48
 
 
bc1c5ed
 
 
a187b48
 
 
 
 
 
 
 
 
 
 
 
 
 
bc1c5ed
a187b48
 
bc1c5ed
 
a187b48
cbae883
 
 
 
 
bc1c5ed
a187b48
 
 
 
bc1c5ed
 
a187b48
cbae883
 
a187b48
 
 
 
 
 
 
 
 
 
 
bc1c5ed
cbae883
a187b48
 
 
bc1c5ed
 
 
a187b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.

# RUN: python %s > %t.ll
# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll

from itertools import product
from string import Template

def make_wmma_slice_ty(abcd, itype):
  elt_ty = "<2 x half>" if itype == "f16" else "float"
  num_elts = 4 if abcd in "cd" and itype == "f16" else 8;
  return [elt_ty] * num_elts

def make_wmma_ld_ret_ty(abc, itype):
  return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))

# returns address space
def get_aspace(space):
  space_map = {
      ".global" : 1,
      ".shared" : 3,
      ".const"  : 4,
      ".local"  : 5,
      ".param"  : 101,
      ""        : 0,
      ".generic": 0
  }
  return space_map[space];

def get_pspace(space):
  return "p%di8" % get_aspace(space);

# Convenient test patterns.
check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)

known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]

def gen_wmma_load_tests():
  load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
  instruction_template = "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}"

  for geom, abc, layout, space, stride, itype in product(
      known_geoms,
      "abc",
      ["row","col"],
      ["",".shared",".global"],
      ["", ".stride"],
      ["f16", "f32"]):

    params = {
        "abc" : abc,
        "layout" : layout,
        "space" : space,
        "stride" : stride,
        "itype" : itype,
        "pspace" : get_pspace(space),
        "as"     : "addrspace(%d)" % get_aspace(space),
        "geom"   : geom,
    }

    if itype == "f32" and abc != "c":
      continue

    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".","_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
    if abc == "c" :
      test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
    else:
      test_params["check_result"] = check_f16_8

    if stride:
      test_params["extra_args"] = ", i32 %stride";
      test_params["stride_pattern"] = ", %r{{[0-9]+}}"
    else:
      test_params["extra_args"] = ""
      test_params["stride_pattern"] = ""

    print(Template(load_template).substitute(test_params))

def make_wmma_slice_args(itype, abcd, prefix="v"):
  return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t
                  in enumerate(make_wmma_slice_ty(abcd, itype))])

def gen_wmma_store_tests():
  store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
  instruction_template = "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}"

  for geom, abc, layout, space, stride, itype in product(
      known_geoms,
      "d",
      ["row","col"],
      ["",".shared",".global"],
      ["", ".stride"],
      ["f16", "f32"]):

    params = {
        "abc" : abc,
        "layout" : layout,
        "space" : space,
        "stride" : stride,
        "itype" : itype,
        "pspace" : get_pspace(space),
        "as"     : "addrspace(%d)" % get_aspace(space),
        "geom"   : geom,
    }

    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".","_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
    test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
    if stride:
      test_params["extra_args"] = ", i32 %stride";
      test_params["stride_pattern"] = ", %r{{[0-9]+}};"
    else:
      test_params["extra_args"] = ""
      test_params["stride_pattern"] = ";"
    test_params["args"] = make_wmma_slice_args(itype, "d");

    print(Template(store_template).substitute(test_params))

def gen_wmma_mma_tests():
  mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
  instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"

  for geom, alayout, blayout, ctype, dtype, satf in product(
      known_geoms,
      ["row","col"],
      ["row","col"],
      ["f16", "f32"],
      ["f16", "f32"],
      [".satfinite", ""]):

    params = {
        "alayout" : alayout,
        "blayout" : blayout,
        "ctype" : ctype,
        "dtype" : dtype,
        "satf"  : satf,
        "geom"  : geom,
    }

    test_params = params
    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(params)
    test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
    test_params["check_ab"] = check_f16_8
    test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
    test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8
    args = ",\n        ".join(make_wmma_slice_args(t, abcd, prefix=abcd)
                              for abcd, t in (("a", "f16"),
                                              ("b", "f16"),
                                              ("c", ctype)))
    test_params["args"] = args
    print(Template(mma_template).substitute(test_params))

def main():
  gen_wmma_load_tests()
  gen_wmma_store_tests()
  gen_wmma_mma_tests()

main()