llvm.org GIT mirror llvm / a3a48d9
add fast-math-flags to 'call' instructions (PR21290) This patch adds optional fast-math-flags (the same that apply to fmul/fadd/fsub/fdiv/frem/fcmp) to call instructions in IR. Follow-up patches would use these flags in LibCallSimplifier, add support to clang, and extend FMF to the DAG for calls. Motivating example: %y = fmul fast float %x, %x %z = tail call float @sqrtf(float %y) We'd like to be able to optimize sqrt(x*x) into fabs(x). We do this today using a function-wide attribute for unsafe-math, but we really want to trigger on the instructions themselves: %z = tail call fast float @sqrtf(float %y) because in an LTO build it's possible that calls with fast semantics have been inlined into a function with non-fast semantics. The code changes and tests are based on the recent commits that added "notail": http://reviews.llvm.org/rL252368 and added FMF to fcmp: http://reviews.llvm.org/rL241901 Differential Revision: http://reviews.llvm.org/D14707 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@255555 91177308-0d34-0410-b5e6-96231b3b80d8 Sanjay Patel 4 years ago
14 changed file(s) with 129 addition(s) and 45 deletion(s). Raw diff Collapse all Expand all
83138313
83148314 ::
83158315
8316 = [tail | musttail | notail ] call [cconv] [ret attrs] [*] () [fn attrs]
8316 = [tail | musttail | notail ] call [fast-math flags] [cconv] [ret attrs] [*] () [fn attrs]
83178317 [ operand bundles ]
83188318
83198319 Overview:
83698369 #. The optional ``notail`` marker indicates that the optimizers should not add
83708370 ``tail`` or ``musttail`` markers to the call. It is used to prevent tail
83718371 call optimization from being performed on the call.
8372
8373 #. The optional ``fast-math flags`` marker indicates that the call has one or more
8374 :ref:`fast-math flags `, which are optimization hints to enable
8375 otherwise unsafe floating-point optimizations. Fast-math flags are only valid
8376 for calls that return a floating-point scalar or vector type.
83728377
83738378 #. The optional "cconv" marker indicates which :ref:`calling
83748379 convention ` the call should use. If none is
347347 CALL_CCONV = 1,
348348 CALL_MUSTTAIL = 14,
349349 CALL_EXPLICIT_TYPE = 15,
350 CALL_NOTAIL = 16
350 CALL_NOTAIL = 16,
351 CALL_FMF = 17 // Call has optional fast-math-flags.
351352 };
352353
353354 // The function body block (FUNCTION_BLOCK_ID) describes function bodies. It
15311531 const Twine &Name = "") {
15321532 return Insert(CallInst::Create(Callee, Args, OpBundles), Name);
15331533 }
1534
15341535 CallInst *CreateCall(Value *Callee, ArrayRef Args,
1535 const Twine &Name) {
1536 return Insert(CallInst::Create(Callee, Args), Name);
1536 const Twine &Name, MDNode *FPMathTag = nullptr) {
1537 PointerType *PTy = cast(Callee->getType());
1538 FunctionType *FTy = cast(PTy->getElementType());
1539 return CreateCall(FTy, Callee, Args, Name, FPMathTag);
15371540 }
15381541
15391542 CallInst *CreateCall(llvm::FunctionType *FTy, Value *Callee,
1540 ArrayRef Args, const Twine &Name = "") {
1541 return Insert(CallInst::Create(FTy, Callee, Args), Name);
1543 ArrayRef Args, const Twine &Name = "",
1544 MDNode *FPMathTag = nullptr) {
1545 CallInst *CI = CallInst::Create(FTy, Callee, Args);
1546 if (isa(CI))
1547 CI = cast(AddFPMathAttributes(CI, FPMathTag, FMF));
1548 return Insert(CI, Name);
15421549 }
15431550
15441551 CallInst *CreateCall(Function *Callee, ArrayRef Args,
1545 const Twine &Name = "") {
1546 return CreateCall(Callee->getFunctionType(), Callee, Args, Name);
1552 const Twine &Name = "", MDNode *FPMathTag = nullptr) {
1553 return CreateCall(Callee->getFunctionType(), Callee, Args, Name, FPMathTag);
15471554 }
15481555
15491556 Value *CreateSelect(Value *C, Value *True, Value *False,
56025602 }
56035603
56045604 /// ParseCall
5605 /// ::= 'call' OptionalCallingConv OptionalAttrs Type Value
5606 /// ParameterList OptionalAttrs
5607 /// ::= 'tail' 'call' OptionalCallingConv OptionalAttrs Type Value
5608 /// ParameterList OptionalAttrs
5609 /// ::= 'musttail' 'call' OptionalCallingConv OptionalAttrs Type Value
5610 /// ParameterList OptionalAttrs
5611 /// ::= 'notail' 'call' OptionalCallingConv OptionalAttrs Type Value
5612 /// ParameterList OptionalAttrs
5605 /// ::= 'call' OptionalFastMathFlags OptionalCallingConv
5606 /// OptionalAttrs Type Value ParameterList OptionalAttrs
5607 /// ::= 'tail' 'call' OptionalFastMathFlags OptionalCallingConv
5608 /// OptionalAttrs Type Value ParameterList OptionalAttrs
5609 /// ::= 'musttail' 'call' OptionalFastMathFlags OptionalCallingConv
5610 /// OptionalAttrs Type Value ParameterList OptionalAttrs
5611 /// ::= 'notail' 'call' OptionalFastMathFlags OptionalCallingConv
5612 /// OptionalAttrs Type Value ParameterList OptionalAttrs
56135613 bool LLParser::ParseCall(Instruction *&Inst, PerFunctionState &PFS,
56145614 CallInst::TailCallKind TCK) {
56155615 AttrBuilder RetAttrs, FnAttrs;
56235623 SmallVector BundleList;
56245624 LocTy CallLoc = Lex.getLoc();
56255625
5626 if ((TCK != CallInst::TCK_None &&
5627 ParseToken(lltok::kw_call,
5628 "expected 'tail call', 'musttail call', or 'notail call'")) ||
5629 ParseOptionalCallingConv(CC) || ParseOptionalReturnAttrs(RetAttrs) ||
5626 if (TCK != CallInst::TCK_None &&
5627 ParseToken(lltok::kw_call,
5628 "expected 'tail call', 'musttail call', or 'notail call'"))
5629 return true;
5630
5631 FastMathFlags FMF = EatFastMathFlagsIfPresent();
5632
5633 if (ParseOptionalCallingConv(CC) || ParseOptionalReturnAttrs(RetAttrs) ||
56305634 ParseType(RetType, RetTypeLoc, true /*void allowed*/) ||
56315635 ParseValID(CalleeID) ||
56325636 ParseParameterList(ArgList, PFS, TCK == CallInst::TCK_MustTail,
56345638 ParseFnAttributeValuePairs(FnAttrs, FwdRefAttrGrps, false, BuiltinLoc) ||
56355639 ParseOptionalOperandBundles(BundleList, PFS))
56365640 return true;
5641
5642 if (FMF.any() && !RetType->isFPOrFPVectorTy())
5643 return Error(CallLoc, "fast-math-flags specified for call without "
5644 "floating-point scalar or vector return type");
56375645
56385646 // If RetType is a non-function pointer type, then this is the short syntax
56395647 // for the call, which means that RetType is just the return type. Infer the
57075715 CallInst *CI = CallInst::Create(Ty, Callee, Args, BundleList);
57085716 CI->setTailCallKind(TCK);
57095717 CI->setCallingConv(CC);
5718 if (FMF.any())
5719 CI->setFastMathFlags(FMF);
57105720 CI->setAttributes(PAL);
57115721 ForwardRefAttrGroups[CI] = FwdRefAttrGrps;
57125722 Inst = CI;
50055005 break;
50065006 }
50075007 case bitc::FUNC_CODE_INST_CALL: {
5008 // CALL: [paramattrs, cc, fnty, fnid, arg0, arg1...]
5008 // CALL: [paramattrs, cc, fmf, fnty, fnid, arg0, arg1...]
50095009 if (Record.size() < 3)
50105010 return error("Invalid record");
50115011
50125012 unsigned OpNum = 0;
50135013 AttributeSet PAL = getAttributes(Record[OpNum++]);
50145014 unsigned CCInfo = Record[OpNum++];
5015
5016 FastMathFlags FMF;
5017 if ((CCInfo >> bitc::CALL_FMF) & 1) {
5018 FMF = getDecodedFastMathFlags(Record[OpNum++]);
5019 if (!FMF.any())
5020 return error("Fast math flags indicator set for call with no FMF");
5021 }
50155022
50165023 FunctionType *FTy = nullptr;
50175024 if (CCInfo >> bitc::CALL_EXPLICIT_TYPE & 1 &&
50745081 TCK = CallInst::TCK_NoTail;
50755082 cast(I)->setTailCallKind(TCK);
50765083 cast(I)->setAttributes(PAL);
5084 if (FMF.any()) {
5085 if (!isa(I))
5086 return error("Fast-math-flags specified for call without "
5087 "floating-point scalar or vector return type");
5088 I->setFastMathFlags(FMF);
5089 }
50775090 break;
50785091 }
50795092 case bitc::FUNC_CODE_INST_VAARG: { // VAARG: [valistty, valist, instty]
21522152 Code = bitc::FUNC_CODE_INST_CALL;
21532153
21542154 Vals.push_back(VE.getAttributeID(CI.getAttributes()));
2155
2156 unsigned Flags = GetOptimizationFlags(&I);
21552157 Vals.push_back(CI.getCallingConv() << bitc::CALL_CCONV |
21562158 unsigned(CI.isTailCall()) << bitc::CALL_TAIL |
21572159 unsigned(CI.isMustTailCall()) << bitc::CALL_MUSTTAIL |
21582160 1 << bitc::CALL_EXPLICIT_TYPE |
2159 unsigned(CI.isNoTailCall()) << bitc::CALL_NOTAIL);
2161 unsigned(CI.isNoTailCall()) << bitc::CALL_NOTAIL |
2162 unsigned(Flags != 0) << bitc::CALL_FMF);
2163 if (Flags != 0)
2164 Vals.push_back(Flags);
2165
21602166 Vals.push_back(VE.getTypeID(FTy));
21612167 PushValueAndType(CI.getCalledValue(), InstID, Vals, VE); // Callee
21622168
663663 ret void
664664 }
665665
666 ; Check various fast math flags and floating-point types on calls.
667
668 declare float @fmf1()
669 declare double @fmf2()
670 declare <4 x double> @fmf3()
671
672 ; CHECK-LABEL: fastMathFlagsForCalls(
673 define void @fastMathFlagsForCalls(float %f, double %d1, <4 x double> %d2) {
674 %call.fast = call fast float @fmf1()
675 ; CHECK: %call.fast = call fast float @fmf1()
676
677 ; Throw in some other attributes to make sure those stay in the right places.
678
679 %call.nsz.arcp = notail call nsz arcp double @fmf2()
680 ; CHECK: %call.nsz.arcp = notail call nsz arcp double @fmf2()
681
682 %call.nnan.ninf = tail call nnan ninf fastcc <4 x double> @fmf3()
683 ; CHECK: %call.nnan.ninf = tail call nnan ninf fastcc <4 x double> @fmf3()
684
685 ret void
686 }
687
666688 ;; Type System
667689 %opaquety = type opaque
668690 define void @typesystem() {
569569 ret double %sqrt
570570
571571 ; CHECK-LABEL: sqrt_intrinsic_arg_squared(
572 ; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
572 ; CHECK-NEXT: %fabs = call fast double @llvm.fabs.f64(double %x)
573573 ; CHECK-NEXT: ret double %fabs
574574 }
575575
583583 ret double %sqrt
584584
585585 ; CHECK-LABEL: sqrt_intrinsic_three_args1(
586 ; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
587 ; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
586 ; CHECK-NEXT: %fabs = call fast double @llvm.fabs.f64(double %x)
587 ; CHECK-NEXT: %sqrt1 = call fast double @llvm.sqrt.f64(double %y)
588588 ; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
589589 ; CHECK-NEXT: ret double %1
590590 }
596596 ret double %sqrt
597597
598598 ; CHECK-LABEL: sqrt_intrinsic_three_args2(
599 ; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
600 ; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
599 ; CHECK-NEXT: %fabs = call fast double @llvm.fabs.f64(double %x)
600 ; CHECK-NEXT: %sqrt1 = call fast double @llvm.sqrt.f64(double %y)
601601 ; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
602602 ; CHECK-NEXT: ret double %1
603603 }
609609 ret double %sqrt
610610
611611 ; CHECK-LABEL: sqrt_intrinsic_three_args3(
612 ; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
613 ; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
612 ; CHECK-NEXT: %fabs = call fast double @llvm.fabs.f64(double %x)
613 ; CHECK-NEXT: %sqrt1 = call fast double @llvm.sqrt.f64(double %y)
614614 ; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
615615 ; CHECK-NEXT: ret double %1
616616 }
622622 ret double %sqrt
623623
624624 ; CHECK-LABEL: sqrt_intrinsic_three_args4(
625 ; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
626 ; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
625 ; CHECK-NEXT: %fabs = call fast double @llvm.fabs.f64(double %x)
626 ; CHECK-NEXT: %sqrt1 = call fast double @llvm.sqrt.f64(double %y)
627627 ; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
628628 ; CHECK-NEXT: ret double %1
629629 }
635635 ret double %sqrt
636636
637637 ; CHECK-LABEL: sqrt_intrinsic_three_args5(
638 ; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
639 ; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
638 ; CHECK-NEXT: %fabs = call fast double @llvm.fabs.f64(double %x)
639 ; CHECK-NEXT: %sqrt1 = call fast double @llvm.sqrt.f64(double %y)
640640 ; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
641641 ; CHECK-NEXT: ret double %1
642642 }
648648 ret double %sqrt
649649
650650 ; CHECK-LABEL: sqrt_intrinsic_three_args6(
651 ; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
652 ; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %y)
651 ; CHECK-NEXT: %fabs = call fast double @llvm.fabs.f64(double %x)
652 ; CHECK-NEXT: %sqrt1 = call fast double @llvm.sqrt.f64(double %y)
653653 ; CHECK-NEXT: %1 = fmul fast double %fabs, %sqrt1
654654 ; CHECK-NEXT: ret double %1
655655 }
674674
675675 ; CHECK-LABEL: sqrt_intrinsic_arg_5th(
676676 ; CHECK-NEXT: %mul = fmul fast double %x, %x
677 ; CHECK-NEXT: %sqrt1 = call double @llvm.sqrt.f64(double %x)
677 ; CHECK-NEXT: %sqrt1 = call fast double @llvm.sqrt.f64(double %x)
678678 ; CHECK-NEXT: %1 = fmul fast double %mul, %sqrt1
679679 ; CHECK-NEXT: ret double %1
680680 }
691691 ret float %sqrt
692692
693693 ; CHECK-LABEL: sqrt_call_squared_f32(
694 ; CHECK-NEXT: %fabs = call float @llvm.fabs.f32(float %x)
694 ; CHECK-NEXT: %fabs = call fast float @llvm.fabs.f32(float %x)
695695 ; CHECK-NEXT: ret float %fabs
696696 }
697697
701701 ret double %sqrt
702702
703703 ; CHECK-LABEL: sqrt_call_squared_f64(
704 ; CHECK-NEXT: %fabs = call double @llvm.fabs.f64(double %x)
704 ; CHECK-NEXT: %fabs = call fast double @llvm.fabs.f64(double %x)
705705 ; CHECK-NEXT: ret double %fabs
706706 }
707707
711711 ret fp128 %sqrt
712712
713713 ; CHECK-LABEL: sqrt_call_squared_f128(
714 ; CHECK-NEXT: %fabs = call fp128 @llvm.fabs.f128(fp128 %x)
714 ; CHECK-NEXT: %fabs = call fast fp128 @llvm.fabs.f128(fp128 %x)
715715 ; CHECK-NEXT: ret fp128 %fabs
716716 }
717717
88 ret float %call
99
1010 ; CHECK-LABEL: @foo(
11 ; CHECK-NEXT: call float @llvm.fabs.f32
11 ; CHECK-NEXT: call fast float @llvm.fabs.f32
1212 ; CHECK-NEXT: ret float
1313 }
1414
77 }
88
99 ; CHECK-LABEL: define double @mylog(
10 ; CHECK: %log = call double @log(double %x) #0
10 ; CHECK: %log = call fast double @log(double %x) #0
1111 ; CHECK: %mul = fmul fast double %log, %y
1212 ; CHECK: ret double %mul
1313 ; CHECK: }
99 ret float %call1
1010
1111 ; CHECK-LABEL: @bar(
12 ; CHECK-NEXT: call float @llvm.fabs.f32
12 ; CHECK-NEXT: call fast float @llvm.fabs.f32
1313 ; CHECK-NEXT: ret float
1414 }
1515
88
99 ; CHECK-LABEL: define double @mypow(
1010 ; CHECK: %mul = fmul fast double %x, %y
11 ; CHECK: %exp = call double @exp(double %mul) #0
11 ; CHECK: %exp = call fast double @exp(double %mul) #0
1212 ; CHECK: ret double %exp
1313 ; CHECK: }
1414
88
99 ; CHECK-LABEL: define double @mypow(
1010 ; CHECK: %mul = fmul fast double %x, %y
11 ; CHECK: %exp2 = call double @exp2(double %mul) #0
11 ; CHECK: %exp2 = call fast double @exp2(double %mul) #0
1212 ; CHECK: ret double %exp2
1313 ; CHECK: }
1414
130130 TEST_F(IRBuilderTest, FastMathFlags) {
131131 IRBuilder<> Builder(BB);
132132 Value *F, *FC;
133 Instruction *FDiv, *FAdd, *FCmp;
133 Instruction *FDiv, *FAdd, *FCmp, *FCall;
134134
135135 F = Builder.CreateLoad(GV);
136136 F = Builder.CreateFAdd(F, F);
204204 ASSERT_TRUE(isa(FC));
205205 FCmp = cast(FC);
206206 EXPECT_TRUE(FCmp->hasAllowReciprocal());
207
208 Builder.clearFastMathFlags();
209
210 // Test a call with FMF.
211 auto CalleeTy = FunctionType::get(Type::getFloatTy(Ctx),
212 /*isVarArg=*/false);
213 auto Callee =
214 Function::Create(CalleeTy, Function::ExternalLinkage, "", M.get());
215
216 FCall = Builder.CreateCall(Callee, None);
217 EXPECT_FALSE(FCall->hasNoNaNs());
218
219 FMF.clear();
220 FMF.setNoNaNs();
221 Builder.SetFastMathFlags(FMF);
222
223 FCall = Builder.CreateCall(Callee, None);
224 EXPECT_TRUE(Builder.getFastMathFlags().any());
225 EXPECT_TRUE(Builder.getFastMathFlags().NoNaNs);
226 EXPECT_TRUE(FCall->hasNoNaNs());
207227
208228 Builder.clearFastMathFlags();
209229