llvm.org GIT mirror llvm / 8992274
[NVPTX] Clean up argument lowering code and properly handle alignment for structs and vectors git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211938 91177308-0d34-0410-b5e6-96231b3b80d8 Justin Holewinski 6 years ago
2 changed file(s) with 90 addition(s) and 91 deletion(s). Raw diff Collapse all Expand all
6464 case MVT::v2f64:
6565 return true;
6666 }
67 }
68
/// GCD - Returns the greatest common divisor of the magnitudes of \p a and
/// \p b.
///
/// GCD(x, 0) == |x| and GCD(0, 0) == 0.  Negative inputs are treated by
/// magnitude, so the result is never a sign-extended negative value.
static uint64_t GCD(int a, int b) {
  // Widen to 64 bits before negating so that INT_MIN cannot overflow.
  uint64_t x = a < 0 ? static_cast<uint64_t>(-static_cast<int64_t>(a))
                     : static_cast<uint64_t>(a);
  uint64_t y = b < 0 ? static_cast<uint64_t>(-static_cast<int64_t>(b))
                     : static_cast<uint64_t>(b);

  // Euclid's algorithm.  No up-front ordering is needed: when x < y the
  // first iteration simply exchanges the two values.
  while (y != 0) {
    uint64_t r = x % y;
    x = y;
    y = r;
  }
  return x;
}
6879
6980 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
517528 } else if (isa(retTy)) {
518529 O << ".param .b" << getPointerTy().getSizeInBits() << " _";
519530 } else {
520 if ((retTy->getTypeID() == Type::StructTyID) || isa(retTy)) {
521 SmallVector vtparts;
522 ComputeValueVTs(*this, retTy, vtparts);
523 unsigned totalsz = 0;
524 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
525 unsigned elems = 1;
526 EVT elemtype = vtparts[i];
527 if (vtparts[i].isVector()) {
528 elems = vtparts[i].getVectorNumElements();
529 elemtype = vtparts[i].getVectorElementType();
530 }
531 // TODO: no need to loop
532 for (unsigned j = 0, je = elems; j != je; ++j) {
533 unsigned sz = elemtype.getSizeInBits();
534 if (elemtype.isInteger() && (sz < 8))
535 sz = 8;
536 totalsz += sz / 8;
537 }
538 }
539 O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
531 if((retTy->getTypeID() == Type::StructTyID) ||
532 isa(retTy)) {
533 O << ".param .align "
534 << retAlignment
535 << " .b8 _["
536 << getDataLayout()->getTypeAllocSize(retTy) << "]";
540537 } else {
541538 assert(false && "Unknown return type");
542539 }
705702 if (Ty->isAggregateType()) {
706703 // aggregate
707704 SmallVector vtparts;
708 ComputeValueVTs(*this, Ty, vtparts);
705 SmallVector Offsets;
706 ComputePTXValueVTs(*this, Ty, vtparts, &Offsets, 0);
709707
710708 unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
711709 // declare .param .align <align> .b8 .param<n>[<size>];
717715 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
718716 DeclareParamOps);
719717 InFlag = Chain.getValue(1);
720 unsigned curOffset = 0;
721718 for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
722 unsigned elems = 1;
723719 EVT elemtype = vtparts[j];
724 if (vtparts[j].isVector()) {
725 elems = vtparts[j].getVectorNumElements();
726 elemtype = vtparts[j].getVectorElementType();
720 unsigned ArgAlign = GCD(align, Offsets[j]);
721 if (elemtype.isInteger() && (sz < 8))
722 sz = 8;
723 SDValue StVal = OutVals[OIdx];
724 if (elemtype.getSizeInBits() < 16) {
725 StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
727726 }
728 for (unsigned k = 0, ke = elems; k != ke; ++k) {
729 unsigned sz = elemtype.getSizeInBits();
730 if (elemtype.isInteger() && (sz < 8))
731 sz = 8;
732 SDValue StVal = OutVals[OIdx];
733 if (elemtype.getSizeInBits() < 16) {
734 StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
735 }
736 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
737 SDValue CopyParamOps[] = { Chain,
738 DAG.getConstant(paramCount, MVT::i32),
739 DAG.getConstant(curOffset, MVT::i32),
740 StVal, InFlag };
741 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
742 CopyParamVTs, CopyParamOps,
743 elemtype, MachinePointerInfo());
744 InFlag = Chain.getValue(1);
745 curOffset += sz / 8;
746 ++OIdx;
747 }
727 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
728 SDValue CopyParamOps[] = { Chain,
729 DAG.getConstant(paramCount, MVT::i32),
730 DAG.getConstant(Offsets[j], MVT::i32),
731 StVal, InFlag };
732 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
733 CopyParamVTs, CopyParamOps,
734 elemtype, MachinePointerInfo(),
735 ArgAlign);
736 InFlag = Chain.getValue(1);
737 ++OIdx;
748738 }
749739 if (vtparts.size() > 0)
750740 --OIdx;
929919 }
930920 // struct or vector
931921 SmallVector vtparts;
922 SmallVector Offsets;
932923 const PointerType *PTy = dyn_cast(Args[i].Ty);
933924 assert(PTy && "Type of a byval parameter should be pointer");
934 ComputeValueVTs(*this, PTy->getElementType(), vtparts);
925 ComputePTXValueVTs(*this, PTy->getElementType(), vtparts, &Offsets, 0);
935926
936927 // declare .param .align <align> .b8 .param<n>[<size>];
937928 unsigned sz = Outs[OIdx].Flags.getByValSize();
938929 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
930 unsigned ArgAlign = Outs[OIdx].Flags.getByValAlign();
939931 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
940932 // so we don't need to worry about natural alignment or not.
941933 // See TargetLowering::LowerCallTo().
947939 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
948940 DeclareParamOps);
949941 InFlag = Chain.getValue(1);
950 unsigned curOffset = 0;
951942 for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
952 unsigned elems = 1;
953943 EVT elemtype = vtparts[j];
954 if (vtparts[j].isVector()) {
955 elems = vtparts[j].getVectorNumElements();
956 elemtype = vtparts[j].getVectorElementType();
944 int curOffset = Offsets[j];
945 unsigned PartAlign = GCD(ArgAlign, curOffset);
946 SDValue srcAddr =
947 DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
948 DAG.getConstant(curOffset, getPointerTy()));
949 SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
950 MachinePointerInfo(), false, false, false,
951 PartAlign);
952 if (elemtype.getSizeInBits() < 16) {
953 theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
957954 }
958 for (unsigned k = 0, ke = elems; k != ke; ++k) {
959 unsigned sz = elemtype.getSizeInBits();
960 if (elemtype.isInteger() && (sz < 8))
961 sz = 8;
962 SDValue srcAddr =
963 DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
964 DAG.getConstant(curOffset, getPointerTy()));
965 SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
966 MachinePointerInfo(), false, false, false,
967 0);
968 if (elemtype.getSizeInBits() < 16) {
969 theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
970 }
971 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
972 SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
973 DAG.getConstant(curOffset, MVT::i32), theVal,
974 InFlag };
975 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
976 CopyParamOps, elemtype,
977 MachinePointerInfo());
978
979 InFlag = Chain.getValue(1);
980 curOffset += sz / 8;
981 }
955 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
956 SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
957 DAG.getConstant(curOffset, MVT::i32), theVal,
958 InFlag };
959 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
960 CopyParamOps, elemtype,
961 MachinePointerInfo());
962
963 InFlag = Chain.getValue(1);
982964 }
983965 ++paramCount;
984966 }
10871069
10881070 // Generate loads from param memory/moves from registers for result
10891071 if (Ins.size() > 0) {
1090 unsigned resoffset = 0;
10911072 if (retTy && retTy->isVectorTy()) {
10921073 EVT ObjectVT = getValueType(retTy);
10931074 unsigned NumElts = ObjectVT.getVectorNumElements();
10961077 ObjectVT) == NumElts &&
10971078 "Vector was not scalarized");
10981079 unsigned sz = EltVT.getSizeInBits();
1099 bool needTruncate = sz < 16 ? true : false;
1080 bool needTruncate = sz < 8 ? true : false;
11001081
11011082 if (NumElts == 1) {
11021083 // Just a simple load
11031084 SmallVector LoadRetVTs;
1104 if (needTruncate) {
1105 // If loading i1 result, generate
1106 // load i16
1085 if (EltVT == MVT::i1 || EltVT == MVT::i8) {
1086 // If loading i1/i8 result, generate
1087 // load.b8 i16
1088 // if i1
11071089 // trunc i16 to i1
11081090 LoadRetVTs.push_back(MVT::i16);
11091091 } else
11271109 } else if (NumElts == 2) {
11281110 // LoadV2
11291111 SmallVector LoadRetVTs;
1130 if (needTruncate) {
1131 // If loading i1 result, generate
1132 // load i16
1112 if (EltVT == MVT::i1 || EltVT == MVT::i8) {
1113 // If loading i1/i8 result, generate
1114 // load.b8 i16
1115 // if i1
11331116 // trunc i16 to i1
11341117 LoadRetVTs.push_back(MVT::i16);
11351118 LoadRetVTs.push_back(MVT::i16);
11721155 EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
11731156 for (unsigned i = 0; i < NumElts; i += VecSize) {
11741157 SmallVector LoadRetVTs;
1175 if (needTruncate) {
1176 // If loading i1 result, generate
1177 // load i16
1158 if (EltVT == MVT::i1 || EltVT == MVT::i8) {
1159 // If loading i1/i8 result, generate
1160 // load.b8 i16
1161 // if i1
11781162 // trunc i16 to i1
11791163 for (unsigned j = 0; j < VecSize; ++j)
11801164 LoadRetVTs.push_back(MVT::i16);
12131197 }
12141198 } else {
12151199 SmallVector VTs;
1216 ComputePTXValueVTs(*this, retTy, VTs);
1200 SmallVector Offsets;
1201 ComputePTXValueVTs(*this, retTy, VTs, &Offsets, 0);
12171202 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1203 unsigned RetAlign = getArgumentAlignment(Callee, CS, retTy, 0);
12181204 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
12191205 unsigned sz = VTs[i].getSizeInBits();
1206 unsigned AlignI = GCD(RetAlign, Offsets[i]);
12201207 bool needTruncate = sz < 8 ? true : false;
12211208 if (VTs[i].isInteger() && (sz < 8))
12221209 sz = 8;
12421229 SmallVector LoadRetOps;
12431230 LoadRetOps.push_back(Chain);
12441231 LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1245 LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
1232 LoadRetOps.push_back(DAG.getConstant(Offsets[i], MVT::i32));
12461233 LoadRetOps.push_back(InFlag);
12471234 SDValue retval = DAG.getMemIntrinsicNode(
12481235 NVPTXISD::LoadParam, dl,
12491236 DAG.getVTList(LoadRetVTs), LoadRetOps,
1250 TheLoadType, MachinePointerInfo());
1237 TheLoadType, MachinePointerInfo(), AlignI);
12511238 Chain = retval.getValue(1);
12521239 InFlag = retval.getValue(2);
12531240 SDValue Ret0 = retval.getValue(0);
12541241 if (needTruncate)
12551242 Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
12561243 InVals.push_back(Ret0);
1257 resoffset += sz / 8;
12581244 }
12591245 }
12601246 }
0 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
1
2 ; CHECK: .visible .func (.param .align 16 .b8 func_retval0[16]) foo0(
3 ; CHECK: .param .align 4 .b8 foo0_param_0[8]
4 define <4 x float> @foo0({float, float} %arg0) {
5 ret <4 x float>
6 }
7
8 ; CHECK: .visible .func (.param .align 8 .b8 func_retval0[8]) foo1(
9 ; CHECK: .param .align 8 .b8 foo1_param_0[16]
10 define <2 x float> @foo1({float, float, i64} %arg0) {
11 ret <2 x float>
12 }