llvm.org GIT mirror llvm / d5fb62a
[SCEV] limit recursion depth of CompareSCEVComplexity Summary: CompareSCEVComplexity goes too deep (50+ on a quite a big unrolled loop) and runs almost infinite time. Added cache of "equal" SCEV pairs to earlier cutoff of further estimation. Recursion depth limit was also introduced as a parameter. Reviewers: sanjoy Subscribers: mzolotukhin, tstellarAMD, llvm-commits Differential Revision: https://reviews.llvm.org/D26389 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@287232 91177308-0d34-0410-b5e6-96231b3b80d8 Daniil Fukalov 2 years ago
2 changed file(s) with 114 addition(s) and 20 deletion(s). Raw diff Collapse all Expand all
126126 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
127127 cl::init(1000));
128128
129 static cl::opt
130 MaxCompareDepth("scalar-evolution-max-compare-depth", cl::Hidden,
131 cl::desc("Maximum depth of recursive compare complexity"),
132 cl::init(32));
133
129134 //===----------------------------------------------------------------------===//
130135 // SCEV class definitions
131136 //===----------------------------------------------------------------------===//
474479 static int
475480 CompareValueComplexity(SmallSet, 8> &EqCache,
476481 const LoopInfo *const LI, Value *LV, Value *RV,
477 unsigned DepthLeft = 2) {
478 if (DepthLeft == 0 || EqCache.count({LV, RV}))
482 unsigned Depth) {
483 if (Depth > MaxCompareDepth || EqCache.count({LV, RV}))
479484 return 0;
480485
481486 // Order pointer values after integer values. This helps SCEVExpander form
536541 for (unsigned Idx : seq(0u, LNumOps)) {
537542 int Result =
538543 CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx),
539 RInst->getOperand(Idx), DepthLeft - 1);
544 RInst->getOperand(Idx), Depth + 1);
540545 if (Result != 0)
541546 return Result;
542 EqCache.insert({LV, RV});
543 }
544 }
545
547 }
548 }
549
550 EqCache.insert({LV, RV});
546551 return 0;
547552 }
548553
549554 // Return negative, zero, or positive, if LHS is less than, equal to, or greater
550555 // than RHS, respectively. A three-way result allows recursive comparisons to be
551556 // more efficient.
552 static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
553 const SCEV *RHS) {
557 static int CompareSCEVComplexity(
558 SmallSet, 8> &EqCacheSCEV,
559 const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
560 unsigned Depth = 0) {
554561 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
555562 if (LHS == RHS)
556563 return 0;
560567 if (LType != RType)
561568 return (int)LType - (int)RType;
562569
570 if (Depth > MaxCompareDepth || EqCacheSCEV.count({LHS, RHS}))
571 return 0;
563572 // Aside from the getSCEVType() ordering, the particular ordering
564573 // isn't very important except that it's beneficial to be consistent,
565574 // so that (a + b) and (b + a) don't end up as different expressions.
569578 const SCEVUnknown *RU = cast(RHS);
570579
571580 SmallSet, 8> EqCache;
572 return CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue());
581 int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(),
582 Depth + 1);
583 if (X == 0)
584 EqCacheSCEV.insert({LHS, RHS});
585 return X;
573586 }
574587
575588 case scConstant: {
604617
605618 // Lexicographically compare.
606619 for (unsigned i = 0; i != LNumOps; ++i) {
607 long X = CompareSCEVComplexity(LI, LA->getOperand(i), RA->getOperand(i));
620 int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i),
621 RA->getOperand(i), Depth + 1);
608622 if (X != 0)
609623 return X;
610624 }
611
625 EqCacheSCEV.insert({LHS, RHS});
612626 return 0;
613627 }
614628
627641 for (unsigned i = 0; i != LNumOps; ++i) {
628642 if (i >= RNumOps)
629643 return 1;
630 long X = CompareSCEVComplexity(LI, LC->getOperand(i), RC->getOperand(i));
644 int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i),
645 RC->getOperand(i), Depth + 1);
631646 if (X != 0)
632647 return X;
633648 }
634 return (int)LNumOps - (int)RNumOps;
649 EqCacheSCEV.insert({LHS, RHS});
650 return 0;
635651 }
636652
637653 case scUDivExpr: {
639655 const SCEVUDivExpr *RC = cast(RHS);
640656
641657 // Lexicographically compare udiv expressions.
642 long X = CompareSCEVComplexity(LI, LC->getLHS(), RC->getLHS());
658 int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(),
659 Depth + 1);
643660 if (X != 0)
644661 return X;
645 return CompareSCEVComplexity(LI, LC->getRHS(), RC->getRHS());
662 X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(),
663 Depth + 1);
664 if (X == 0)
665 EqCacheSCEV.insert({LHS, RHS});
666 return X;
646667 }
647668
648669 case scTruncate:
652673 const SCEVCastExpr *RC = cast(RHS);
653674
654675 // Compare cast expressions by operand.
655 return CompareSCEVComplexity(LI, LC->getOperand(), RC->getOperand());
676 int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(),
677 RC->getOperand(), Depth + 1);
678 if (X == 0)
679 EqCacheSCEV.insert({LHS, RHS});
680 return X;
656681 }
657682
658683 case scCouldNotCompute:
674699 static void GroupByComplexity(SmallVectorImpl &Ops,
675700 LoopInfo *LI) {
676701 if (Ops.size() < 2) return; // Noop
702
703 SmallSet, 8> EqCache;
677704 if (Ops.size() == 2) {
678705 // This is the common case, which also happens to be trivially simple.
679706 // Special case it.
680707 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
681 if (CompareSCEVComplexity(LI, RHS, LHS) < 0)
708 if (CompareSCEVComplexity(EqCache, LI, RHS, LHS) < 0)
682709 std::swap(LHS, RHS);
683710 return;
684711 }
685712
686713 // Do the rough sort by complexity.
687714 std::stable_sort(Ops.begin(), Ops.end(),
688 [LI](const SCEV *LHS, const SCEV *RHS) {
689 return CompareSCEVComplexity(LI, LHS, RHS) < 0;
715 [&EqCache, LI](const SCEV *LHS, const SCEV *RHS) {
716 return CompareSCEVComplexity(EqCache, LI, LHS, RHS) < 0;
690717 });
691718
692719 // Now that we are sorted by complexity, group elements of the same
464464 });
465465 }
466466
467 TEST_F(ScalarEvolutionsTest, SCEVCompareComplexity) {
468 FunctionType *FTy =
469 FunctionType::get(Type::getVoidTy(Context), std::vector(), false);
470 Function *F = cast(M.getOrInsertFunction("f", FTy));
471 BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
472 BasicBlock *LoopBB = BasicBlock::Create(Context, "bb1", F);
473 BranchInst::Create(LoopBB, EntryBB);
474
475 auto *Ty = Type::getInt32Ty(Context);
476 SmallVector Muls(8), Acc(8), NextAcc(8);
477
478 Acc[0] = PHINode::Create(Ty, 2, "", LoopBB);
479 Acc[1] = PHINode::Create(Ty, 2, "", LoopBB);
480 Acc[2] = PHINode::Create(Ty, 2, "", LoopBB);
481 Acc[3] = PHINode::Create(Ty, 2, "", LoopBB);
482 Acc[4] = PHINode::Create(Ty, 2, "", LoopBB);
483 Acc[5] = PHINode::Create(Ty, 2, "", LoopBB);
484 Acc[6] = PHINode::Create(Ty, 2, "", LoopBB);
485 Acc[7] = PHINode::Create(Ty, 2, "", LoopBB);
486
487 for (int i = 0; i < 20; i++) {
488 Muls[0] = BinaryOperator::CreateMul(Acc[0], Acc[0], "", LoopBB);
489 NextAcc[0] = BinaryOperator::CreateAdd(Muls[0], Acc[4], "", LoopBB);
490 Muls[1] = BinaryOperator::CreateMul(Acc[1], Acc[1], "", LoopBB);
491 NextAcc[1] = BinaryOperator::CreateAdd(Muls[1], Acc[5], "", LoopBB);
492 Muls[2] = BinaryOperator::CreateMul(Acc[2], Acc[2], "", LoopBB);
493 NextAcc[2] = BinaryOperator::CreateAdd(Muls[2], Acc[6], "", LoopBB);
494 Muls[3] = BinaryOperator::CreateMul(Acc[3], Acc[3], "", LoopBB);
495 NextAcc[3] = BinaryOperator::CreateAdd(Muls[3], Acc[7], "", LoopBB);
496
497 Muls[4] = BinaryOperator::CreateMul(Acc[4], Acc[4], "", LoopBB);
498 NextAcc[4] = BinaryOperator::CreateAdd(Muls[4], Acc[0], "", LoopBB);
499 Muls[5] = BinaryOperator::CreateMul(Acc[5], Acc[5], "", LoopBB);
500 NextAcc[5] = BinaryOperator::CreateAdd(Muls[5], Acc[1], "", LoopBB);
501 Muls[6] = BinaryOperator::CreateMul(Acc[6], Acc[6], "", LoopBB);
502 NextAcc[6] = BinaryOperator::CreateAdd(Muls[6], Acc[2], "", LoopBB);
503 Muls[7] = BinaryOperator::CreateMul(Acc[7], Acc[7], "", LoopBB);
504 NextAcc[7] = BinaryOperator::CreateAdd(Muls[7], Acc[3], "", LoopBB);
505 Acc = NextAcc;
506 }
507
508 auto II = LoopBB->begin();
509 for (int i = 0; i < 8; i++) {
510 PHINode *Phi = cast(&*II++);
511 Phi->addIncoming(Acc[i], LoopBB);
512 Phi->addIncoming(UndefValue::get(Ty), EntryBB);
513 }
514
515 BasicBlock *ExitBB = BasicBlock::Create(Context, "bb2", F);
516 BranchInst::Create(LoopBB, ExitBB, UndefValue::get(Type::getInt1Ty(Context)),
517 LoopBB);
518
519 Acc[0] = BinaryOperator::CreateAdd(Acc[0], Acc[1], "", ExitBB);
520 Acc[1] = BinaryOperator::CreateAdd(Acc[2], Acc[3], "", ExitBB);
521 Acc[2] = BinaryOperator::CreateAdd(Acc[4], Acc[5], "", ExitBB);
522 Acc[3] = BinaryOperator::CreateAdd(Acc[6], Acc[7], "", ExitBB);
523 Acc[0] = BinaryOperator::CreateAdd(Acc[0], Acc[1], "", ExitBB);
524 Acc[1] = BinaryOperator::CreateAdd(Acc[2], Acc[3], "", ExitBB);
525 Acc[0] = BinaryOperator::CreateAdd(Acc[0], Acc[1], "", ExitBB);
526
527 ReturnInst::Create(Context, nullptr, ExitBB);
528
529 ScalarEvolution SE = buildSE(*F);
530
531 EXPECT_NE(nullptr, SE.getSCEV(Acc[0]));
532 }
533
467534 } // end anonymous namespace
468535 } // end namespace llvm