llvm.org GIT mirror llvm / 6e8b53d
Masked Load/Store - fixed a bug in type legalization. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@225441 91177308-0d34-0410-b5e6-96231b3b80d8 Elena Demikhovsky 5 years ago
4 changed file(s) with 156 addition(s) and 3 deletion(s). Raw diff Collapse all Expand all
6565 case ISD::EXTRACT_VECTOR_ELT:
6666 Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
6767 case ISD::LOAD: Res = PromoteIntRes_LOAD(cast(N));break;
68 case ISD::MLOAD: Res = PromoteIntRes_MLOAD(cast(N));break;
6869 case ISD::SELECT: Res = PromoteIntRes_SELECT(N); break;
6970 case ISD::VSELECT: Res = PromoteIntRes_VSELECT(N); break;
7071 case ISD::SELECT_CC: Res = PromoteIntRes_SELECT_CC(N); break;
453454 return Res;
454455 }
455456
/// Promote the integer result of a masked load.
///
/// Builds a new masked load whose value type is NVT, the type the target
/// transforms N's result type to. The Src0 operand is replaced by its promoted
/// integer form and the mask is passed through PromoteTargetBoolean against
/// NVT. Users of N's chain result (value 1) are redirected to the chain of the
/// new node. Returns the new masked load.
457 SDValue DAGTypeLegalizer::PromoteIntRes_MLOAD(MaskedLoadSDNode *N) {
458 EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
459 SDValue ExtSrc0 = GetPromotedInteger(N->getSrc0());
460 SDValue ExtMask = PromoteTargetBoolean(N->getMask(), NVT);
461 SDLoc dl(N);
462
// Create a fresh load memory operand: pointer info, alignment, AA info and
// ranges are copied from N, but the size is the store size of NVT.
// NOTE(review): the access size is taken from the promoted type rather than
// from N's original type - confirm this is the intended memory footprint.
463 MachineMemOperand *MMO = DAG.getMachineFunction().
464 getMachineMemOperand(N->getPointerInfo(),
465 MachineMemOperand::MOLoad, NVT.getStoreSize(),
466 N->getAlignment(), N->getAAInfo(), N->getRanges());
467
// The new node reuses N's chain and base pointer; only the value type, mask,
// Src0 and memory operand differ.
468 SDValue Res = DAG.getMaskedLoad(NVT, dl, N->getChain(), N->getBasePtr(),
469 ExtMask, ExtSrc0, MMO);
470 // Legalized the chain result - switch anything that used the old chain to
471 // use the new one.
472 ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
473 return Res;
474 }
456475 /// Promote the overflow flag of an overflowing arithmetic node.
457476 SDValue DAGTypeLegalizer::PromoteIntRes_Overflow(SDNode *N) {
458477 // Simply change the return type of the boolean result.
/// Legalize the mask operand (operand 2) of a masked store, and legalize the
/// data operand (operand 3) alongside it when the data type is not legal.
///
/// NOTE: this span is a rendered diff. The lines prefixed with a single old
/// line number (1100, 1101, 1103) appear to be the pre-change lines that the
/// commit removed; the remaining body lines are the post-change code.
///
/// Post-change behaviour, as visible here:
///  - Data type legal: the mask is promoted with PromoteTargetBoolean against
///    the data type.
///  - Data type promoted as an integer: the data is replaced by its promoted
///    form and the mask is promoted against the promoted data type.
///  - Otherwise the data type action is asserted to be vector widening: the
///    data is widened, and the mask is either widened too or padded with zero
///    vectors up to the setcc result type of the widened data.
/// The node's operands 2 and 3 are then replaced via UpdateNodeOperands and
/// the resulting node's value 0 is returned.
10971116 SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N, unsigned OpNo){
10981117
10991118 assert(OpNo == 2 && "Only know how to promote the mask!");
1100 EVT DataVT = N->getOperand(3).getValueType();
1101 SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT);
1119 SDValue DataOp = N->getData();
1120 EVT DataVT = DataOp.getValueType();
1121 SDValue Mask = N->getMask();
1122 EVT MaskVT = Mask.getValueType();
1123 SDLoc dl(N);
1124
// The data operand itself may still need legalization; handle it together
// with the mask so both operands end up with matching types.
1125 if (!TLI.isTypeLegal(DataVT)) {
1126 if (getTypeAction(DataVT) == TargetLowering::TypePromoteInteger) {
1127 DataOp = GetPromotedInteger(DataOp);
1128 Mask = PromoteTargetBoolean(Mask, DataOp.getValueType());
1129 }
1130 else {
1131 assert(getTypeAction(DataVT) == TargetLowering::TypeWidenVector &&
1132 "Unexpected data legalization in MSTORE");
1133 DataOp = GetWidenedVector(DataOp);
1134
1135 if (getTypeAction(MaskVT) == TargetLowering::TypeWidenVector)
1136 Mask = GetWidenedVector(Mask);
1137 else {
// The mask is not scheduled for widening itself, so build the wider mask
// here: original mask first, followed by all-zero vectors of the mask type.
1138 EVT BoolVT = getSetCCResultType(DataOp.getValueType());
1139
1140 // We can't use ModifyToType() because we should fill the mask with
1141 // zeroes
1142 unsigned WidenNumElts = BoolVT.getVectorNumElements();
1143 unsigned MaskNumElts = MaskVT.getVectorNumElements();
1144
// NOTE(review): integer division - assumes WidenNumElts is a non-zero
// multiple of MaskNumElts; confirm (Ops[0] below requires NumConcat >= 1).
1145 unsigned NumConcat = WidenNumElts / MaskNumElts;
1146 SmallVector Ops(NumConcat);
1147 SDValue ZeroVal = DAG.getConstant(0, MaskVT);
1148 Ops[0] = Mask;
1149 for (unsigned i = 1; i != NumConcat; ++i)
1150 Ops[i] = ZeroVal;
1151
1152 Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, BoolVT, Ops);
1153 }
1154 }
1155 }
1156 else
1157 Mask = PromoteTargetBoolean(N->getMask(), DataOp.getValueType());
// Copy all existing operands, then overwrite the mask (index 2) and the data
// (index 3) with their legalized forms.
11021158 SmallVector NewOps(N->op_begin(), N->op_end());
1103 NewOps[OpNo] = Mask;
1159 NewOps[2] = Mask;
1160 NewOps[3] = DataOp;
11041161 return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
11051162 }
11061163
239239 SDValue PromoteIntRes_FP_TO_FP16(SDNode *N);
240240 SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
241241 SDValue PromoteIntRes_LOAD(LoadSDNode *N);
242 SDValue PromoteIntRes_MLOAD(MaskedLoadSDNode *N);
242243 SDValue PromoteIntRes_Overflow(SDNode *N);
243244 SDValue PromoteIntRes_SADDSUBO(SDNode *N, unsigned ResNo);
244245 SDValue PromoteIntRes_SDIV(SDNode *N);
630631 SDValue WidenVecRes_EXTRACT_SUBVECTOR(SDNode* N);
631632 SDValue WidenVecRes_INSERT_VECTOR_ELT(SDNode* N);
632633 SDValue WidenVecRes_LOAD(SDNode* N);
634 SDValue WidenVecRes_MLOAD(MaskedLoadSDNode* N);
633635 SDValue WidenVecRes_SCALAR_TO_VECTOR(SDNode* N);
634636 SDValue WidenVecRes_SIGN_EXTEND_INREG(SDNode* N);
635637 SDValue WidenVecRes_SELECT(SDNode* N);
17121712 case ISD::VECTOR_SHUFFLE:
17131713 Res = WidenVecRes_VECTOR_SHUFFLE(cast(N));
17141714 break;
1715 case ISD::MLOAD:
1716 Res = WidenVecRes_MLOAD(cast(N));
1717 break;
17151718
17161719 case ISD::ADD:
17171720 case ISD::AND:
24002403 ReplaceValueWith(SDValue(N, 1), NewChain);
24012404
24022405 return Result;
2406 }
2407
/// Widen the vector result of a masked load.
///
/// Builds a new masked load producing WidenVT, the type the target transforms
/// N's result type to. The Src0 operand is replaced by its widened form. The
/// mask is either widened as well (when its own type action is vector
/// widening) or extended by concatenating all-zero vectors of the mask type
/// after it. Users of N's chain result (value 1) are redirected to the chain
/// of the new node. Returns the new masked load.
2408 SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
2409
2410 EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(),N->getValueType(0));
2411 SDValue Mask = N->getMask();
2412 EVT MaskVT = Mask.getValueType();
2413 SDValue Src0 = GetWidenedVector(N->getSrc0());
2414 SDLoc dl(N);
2415
2416 if (getTypeAction(MaskVT) == TargetLowering::TypeWidenVector)
2417 Mask = GetWidenedVector(Mask);
2418 else {
// Target mask type is the setcc result type for the widened value type.
2419 EVT BoolVT = getSetCCResultType(WidenVT);
2420
2421 // We can't use ModifyToType() because we should fill the mask with
2422 // zeroes
2423 unsigned WidenNumElts = BoolVT.getVectorNumElements();
2424 unsigned MaskNumElts = MaskVT.getVectorNumElements();
2425
// NOTE(review): integer division - assumes WidenNumElts is a non-zero
// multiple of MaskNumElts; confirm (Ops[0] below requires NumConcat >= 1).
2426 unsigned NumConcat = WidenNumElts / MaskNumElts;
2427 SmallVector Ops(NumConcat);
2428 SDValue ZeroVal = DAG.getConstant(0, MaskVT);
// Original mask goes first; every remaining slot is a zero vector.
2429 Ops[0] = Mask;
2430 for (unsigned i = 1; i != NumConcat; ++i)
2431 Ops[i] = ZeroVal;
2432
2433 Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, BoolVT, Ops);
2434 }
2435
2436 // Rebuild memory operand because MemoryVT was changed
// Pointer info, alignment, AA info and ranges are copied from N; the size is
// the store size of WidenVT.
2437 MachineMemOperand *MMO = DAG.getMachineFunction().
2438 getMachineMemOperand(N->getPointerInfo(),
2439 MachineMemOperand::MOLoad, WidenVT.getStoreSize(),
2440 N->getAlignment(), N->getAAInfo(), N->getRanges());
2441
2442 SDValue Res = DAG.getMaskedLoad(WidenVT, dl, N->getChain(), N->getBasePtr(),
2443 Mask, Src0, MMO);
2444 // Legalized the chain result - switch anything that used the old chain to
2445 // use the new one.
2446 ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
2447 return Res;
24032448 }
24042449
24052450 SDValue DAGTypeLegalizer::WidenVecRes_SCALAR_TO_VECTOR(SDNode *N) {
148148 ret void
149149 }
150150
; test14 - masked store of a <2 x float> value through a <2 x float>* with
; alignment argument 4. The <2 x i1> mask is the lane-wise compare of %trigger
; against zero. The AVX2 check lines below expect a vshufps with immediate -24
; followed by a vmaskmovps.
151 ; AVX2-LABEL: test14
152 ; AVX2: vshufps $-24
153 ; AVX2: vmaskmovps
154 define void @test14(<2 x i32> %trigger, <2 x float>* %addr, <2 x float> %val) {
155 %mask = icmp eq <2 x i32> %trigger, zeroinitializer
156 call void @llvm.masked.store.v2f32(<2 x float>%val, <2 x float>* %addr, i32 4, <2 x i1>%mask)
157 ret void
158 }
159
; test15 - masked store of a <2 x i32> value through a <2 x i32>* with
; alignment argument 4. The <2 x i1> mask is the lane-wise compare of %trigger
; against zero. The AVX2 check line below expects a vpmaskmovq.
160 ; AVX2-LABEL: test15
161 ; AVX2: vpmaskmovq
162 define void @test15(<2 x i32> %trigger, <2 x i32>* %addr, <2 x i32> %val) {
163 %mask = icmp eq <2 x i32> %trigger, zeroinitializer
164 call void @llvm.masked.store.v2i32(<2 x i32>%val, <2 x i32>* %addr, i32 4, <2 x i1>%mask)
165 ret void
166 }
167
; test16 - masked load of <2 x float> from %addr with alignment argument 4,
; with %dst supplied as the final intrinsic argument. The <2 x i1> mask is the
; lane-wise compare of %trigger against zero. The AVX2 check lines below expect
; a vmaskmovps followed by a vblendvps.
168 ; AVX2-LABEL: test16
169 ; AVX2: vmaskmovps
170 ; AVX2: vblendvps
171 define <2 x float> @test16(<2 x i32> %trigger, <2 x float>* %addr, <2 x float> %dst) {
172 %mask = icmp eq <2 x i32> %trigger, zeroinitializer
173 %res = call <2 x float> @llvm.masked.load.v2f32(<2 x float>* %addr, i32 4, <2 x i1>%mask, <2 x float>%dst)
174 ret <2 x float> %res
175 }
176
; test17 - masked load of <2 x i32> from %addr with alignment argument 4,
; with %dst supplied as the final intrinsic argument. The <2 x i1> mask is the
; lane-wise compare of %trigger against zero. The AVX2 check lines below expect
; a vpmaskmovq followed by a vblendvpd.
177 ; AVX2-LABEL: test17
178 ; AVX2: vpmaskmovq
179 ; AVX2: vblendvpd
180 define <2 x i32> @test17(<2 x i32> %trigger, <2 x i32>* %addr, <2 x i32> %dst) {
181 %mask = icmp eq <2 x i32> %trigger, zeroinitializer
182 %res = call <2 x i32> @llvm.masked.load.v2i32(<2 x i32>* %addr, i32 4, <2 x i1>%mask, <2 x i32>%dst)
183 ret <2 x i32> %res
184 }
185
; test18 - masked load of <2 x float> from %addr with alignment argument 4,
; with undef supplied as the final intrinsic argument. The <2 x i1> mask is the
; lane-wise compare of %trigger against zero. The AVX2 check lines below expect
; a vmaskmovps and require that no later output line contains "blend".
186 ; AVX2-LABEL: test18
187 ; AVX2: vmaskmovps
188 ; AVX2-NOT: blend
189 define <2 x float> @test18(<2 x i32> %trigger, <2 x float>* %addr) {
190 %mask = icmp eq <2 x i32> %trigger, zeroinitializer
191 %res = call <2 x float> @llvm.masked.load.v2f32(<2 x float>* %addr, i32 4, <2 x i1>%mask, <2 x float>undef)
192 ret <2 x float> %res
193 }
194
195
151196 declare <16 x i32> @llvm.masked.load.v16i32(<16 x i32>*, i32, <16 x i1>, <16 x i32>)
152197 declare <4 x i32> @llvm.masked.load.v4i32(<4 x i32>*, i32, <4 x i1>, <4 x i32>)
198 declare <2 x i32> @llvm.masked.load.v2i32(<2 x i32>*, i32, <2 x i1>, <2 x i32>)
153199 declare void @llvm.masked.store.v16i32(<16 x i32>, <16 x i32>*, i32, <16 x i1>)
154200 declare void @llvm.masked.store.v8i32(<8 x i32>, <8 x i32>*, i32, <8 x i1>)
155201 declare void @llvm.masked.store.v4i32(<4 x i32>, <4 x i32>*, i32, <4 x i1>)
202 declare void @llvm.masked.store.v2f32(<2 x float>, <2 x float>*, i32, <2 x i1>)
203 declare void @llvm.masked.store.v2i32(<2 x i32>, <2 x i32>*, i32, <2 x i1>)
156204 declare void @llvm.masked.store.v16f32(<16 x float>, <16 x float>*, i32, <16 x i1>)
157205 declare void @llvm.masked.store.v16f32p(<16 x float>*, <16 x float>**, i32, <16 x i1>)
158206 declare <16 x float> @llvm.masked.load.v16f32(<16 x float>*, i32, <16 x i1>, <16 x float>)
159207 declare <8 x float> @llvm.masked.load.v8f32(<8 x float>*, i32, <8 x i1>, <8 x float>)
160208 declare <4 x float> @llvm.masked.load.v4f32(<4 x float>*, i32, <4 x i1>, <4 x float>)
209 declare <2 x float> @llvm.masked.load.v2f32(<2 x float>*, i32, <2 x i1>, <2 x float>)
161210 declare <8 x double> @llvm.masked.load.v8f64(<8 x double>*, i32, <8 x i1>, <8 x double>)
162211 declare <4 x double> @llvm.masked.load.v4f64(<4 x double>*, i32, <4 x i1>, <4 x double>)
163212 declare <2 x double> @llvm.masked.load.v2f64(<2 x double>*, i32, <2 x i1>, <2 x double>)