1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are modeled as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
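//
// For illustration only (a hypothetical loop, not tied to any code below):
// given "for (i = 0; i != n; ++i)", the induction variable i is modeled as the
// polynomial recurrence {0,+,1}<nuw><nsw><%loop> -- start 0, step 1 -- and an
// expression such as 4*i + 7 folds to the recurrence {7,+,4}<%loop>. Because
// SCEVs are uniqued, two occurrences of 4*i + 7 yield the same SCEV pointer.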
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/ScopeExit.h"
69#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/SmallSet.h"
73#include "llvm/ADT/Statistic.h"
75#include "llvm/ADT/StringRef.h"
84#include "llvm/Config/llvm-config.h"
85#include "llvm/IR/Argument.h"
86#include "llvm/IR/BasicBlock.h"
87#include "llvm/IR/CFG.h"
88#include "llvm/IR/Constant.h"
90#include "llvm/IR/Constants.h"
91#include "llvm/IR/DataLayout.h"
93#include "llvm/IR/Dominators.h"
94#include "llvm/IR/Function.h"
95#include "llvm/IR/GlobalAlias.h"
96#include "llvm/IR/GlobalValue.h"
98#include "llvm/IR/InstrTypes.h"
99#include "llvm/IR/Instruction.h"
100#include "llvm/IR/Instructions.h"
102#include "llvm/IR/Intrinsics.h"
103#include "llvm/IR/LLVMContext.h"
104#include "llvm/IR/Operator.h"
105#include "llvm/IR/PatternMatch.h"
106#include "llvm/IR/Type.h"
107#include "llvm/IR/Use.h"
108#include "llvm/IR/User.h"
109#include "llvm/IR/Value.h"
110#include "llvm/IR/Verifier.h"
112#include "llvm/Pass.h"
113#include "llvm/Support/Casting.h"
116#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136
137#define DEBUG_TYPE "scalar-evolution"
138
139STATISTIC(NumExitCountsComputed,
140 "Number of loop exits with predictable exit counts");
141STATISTIC(NumExitCountsNotComputed,
142 "Number of loop exits without predictable exit counts");
143STATISTIC(NumBruteForceTripCountsComputed,
144 "Number of loops with trip counts computed by force");
145
146#ifdef EXPENSIVE_CHECKS
147bool llvm::VerifySCEV = true;
148#else
149bool llvm::VerifySCEV = false;
150#endif
151
152static cl::opt<unsigned>
153 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
154 cl::desc("Maximum number of iterations SCEV will "
155 "symbolically execute a constant "
156 "derived loop"),
157 cl::init(100));
158
159static cl::opt<bool, true> VerifySCEVOpt(
160 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
161 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
162static cl::opt<bool> VerifySCEVStrict(
163 "verify-scev-strict", cl::Hidden,
164 cl::desc("Enable stricter verification with -verify-scev is passed"));
165
166static cl::opt<bool> VerifyIR(
167 "scev-verify-ir", cl::Hidden,
168 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
169 cl::init(false));
170
171static cl::opt<unsigned> MulOpsInlineThreshold(
172 "scev-mulops-inline-threshold", cl::Hidden,
173 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
174 cl::init(32));
175
176static cl::opt<unsigned> AddOpsInlineThreshold(
177 "scev-addops-inline-threshold", cl::Hidden,
178 cl::desc("Threshold for inlining addition operands into a SCEV"),
179 cl::init(500));
180
181static cl::opt<unsigned> MaxSCEVCompareDepth(
182 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
183 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
184 cl::init(32));
185
186static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
187 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
188 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
189 cl::init(2));
190
191static cl::opt<unsigned> MaxValueCompareDepth(
192 "scalar-evolution-max-value-compare-depth", cl::Hidden,
193 cl::desc("Maximum depth of recursive value complexity comparisons"),
194 cl::init(2));
195
196static cl::opt<unsigned>
197 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
198 cl::desc("Maximum depth of recursive arithmetics"),
199 cl::init(32));
200
201static cl::opt<unsigned> MaxConstantEvolvingDepth(
202 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
203 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
204
205static cl::opt<unsigned>
206 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
207 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
208 cl::init(8));
209
210static cl::opt<unsigned>
211 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
212 cl::desc("Max coefficients in AddRec during evolving"),
213 cl::init(8));
214
215static cl::opt<unsigned>
216 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
217 cl::desc("Size of the expression which is considered huge"),
218 cl::init(4096));
219
220static cl::opt<unsigned> RangeIterThreshold(
221 "scev-range-iter-threshold", cl::Hidden,
222 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
223 cl::init(32));
224
225static cl::opt<bool>
226ClassifyExpressions("scalar-evolution-classify-expressions",
227 cl::Hidden, cl::init(true),
228 cl::desc("When printing analysis, include information on every instruction"));
229
230static cl::opt<bool> UseExpensiveRangeSharpening(
231 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
232 cl::init(false),
233 cl::desc("Use more powerful methods of sharpening expression ranges. May "
234 "be costly in terms of compile time"));
235
236static cl::opt<unsigned> MaxPhiSCCAnalysisSize(
237 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
238 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
239 "Phi strongly connected components"),
240 cl::init(8));
241
242static cl::opt<bool>
243 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
244 cl::desc("Handle <= and >= in finite loops"),
245 cl::init(true));
246
247static cl::opt<bool> UseContextForNoWrapFlagInference(
248 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
249 cl::desc("Infer nuw/nsw flags using context where suitable"),
250 cl::init(true));
251
252//===----------------------------------------------------------------------===//
253// SCEV class definitions
254//===----------------------------------------------------------------------===//
255
256//===----------------------------------------------------------------------===//
257// Implementation of the SCEV class.
258//
259
260#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
261LLVM_DUMP_METHOD void SCEV::dump() const {
262 print(dbgs());
263 dbgs() << '\n';
264}
265#endif
266
267void SCEV::print(raw_ostream &OS) const {
268 switch (getSCEVType()) {
269 case scConstant:
270 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
271 return;
272 case scVScale:
273 OS << "vscale";
274 return;
275 case scPtrToInt: {
276 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
277 const SCEV *Op = PtrToInt->getOperand();
278 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
279 << *PtrToInt->getType() << ")";
280 return;
281 }
282 case scTruncate: {
283 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
284 const SCEV *Op = Trunc->getOperand();
285 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
286 << *Trunc->getType() << ")";
287 return;
288 }
289 case scZeroExtend: {
290 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
291 const SCEV *Op = ZExt->getOperand();
292 OS << "(zext " << *Op->getType() << " " << *Op << " to "
293 << *ZExt->getType() << ")";
294 return;
295 }
296 case scSignExtend: {
297 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
298 const SCEV *Op = SExt->getOperand();
299 OS << "(sext " << *Op->getType() << " " << *Op << " to "
300 << *SExt->getType() << ")";
301 return;
302 }
303 case scAddRecExpr: {
304 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
305 OS << "{" << *AR->getOperand(0);
306 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
307 OS << ",+," << *AR->getOperand(i);
308 OS << "}<";
309 if (AR->hasNoUnsignedWrap())
310 OS << "nuw><";
311 if (AR->hasNoSignedWrap())
312 OS << "nsw><";
313 if (AR->hasNoSelfWrap() &&
314 !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
315 OS << "nw><";
316 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
317 OS << ">";
318 return;
319 }
320 case scAddExpr:
321 case scMulExpr:
322 case scUMaxExpr:
323 case scSMaxExpr:
324 case scUMinExpr:
325 case scSMinExpr:
326 case scSequentialUMinExpr: {
327 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
328 const char *OpStr = nullptr;
329 switch (NAry->getSCEVType()) {
330 case scAddExpr: OpStr = " + "; break;
331 case scMulExpr: OpStr = " * "; break;
332 case scUMaxExpr: OpStr = " umax "; break;
333 case scSMaxExpr: OpStr = " smax "; break;
334 case scUMinExpr:
335 OpStr = " umin ";
336 break;
337 case scSMinExpr:
338 OpStr = " smin ";
339 break;
340 case scSequentialUMinExpr:
341 OpStr = " umin_seq ";
342 break;
343 default:
344 llvm_unreachable("There are no other nary expression types.");
345 }
346 OS << "(";
347 ListSeparator LS(OpStr);
348 for (const SCEV *Op : NAry->operands())
349 OS << LS << *Op;
350 OS << ")";
351 switch (NAry->getSCEVType()) {
352 case scAddExpr:
353 case scMulExpr:
354 if (NAry->hasNoUnsignedWrap())
355 OS << "<nuw>";
356 if (NAry->hasNoSignedWrap())
357 OS << "<nsw>";
358 break;
359 default:
360 // Nothing to print for other nary expressions.
361 break;
362 }
363 return;
364 }
365 case scUDivExpr: {
366 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
367 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
368 return;
369 }
370 case scUnknown:
371 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
372 return;
373 case scCouldNotCompute:
374 OS << "***COULDNOTCOMPUTE***";
375 return;
376 }
377 llvm_unreachable("Unknown SCEV kind!");
378}
379
380Type *SCEV::getType() const {
381 switch (getSCEVType()) {
382 case scConstant:
383 return cast<SCEVConstant>(this)->getType();
384 case scVScale:
385 return cast<SCEVVScale>(this)->getType();
386 case scPtrToInt:
387 case scTruncate:
388 case scZeroExtend:
389 case scSignExtend:
390 return cast<SCEVCastExpr>(this)->getType();
391 case scAddRecExpr:
392 return cast<SCEVAddRecExpr>(this)->getType();
393 case scMulExpr:
394 return cast<SCEVMulExpr>(this)->getType();
395 case scUMaxExpr:
396 case scSMaxExpr:
397 case scUMinExpr:
398 case scSMinExpr:
399 return cast<SCEVMinMaxExpr>(this)->getType();
400 case scSequentialUMinExpr:
401 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
402 case scAddExpr:
403 return cast<SCEVAddExpr>(this)->getType();
404 case scUDivExpr:
405 return cast<SCEVUDivExpr>(this)->getType();
406 case scUnknown:
407 return cast<SCEVUnknown>(this)->getType();
408 case scCouldNotCompute:
409 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
410 }
411 llvm_unreachable("Unknown SCEV kind!");
412}
413
414ArrayRef<const SCEV *> SCEV::operands() const {
415 switch (getSCEVType()) {
416 case scConstant:
417 case scVScale:
418 case scUnknown:
419 return {};
420 case scPtrToInt:
421 case scTruncate:
422 case scZeroExtend:
423 case scSignExtend:
424 return cast<SCEVCastExpr>(this)->operands();
425 case scAddRecExpr:
426 case scAddExpr:
427 case scMulExpr:
428 case scUMaxExpr:
429 case scSMaxExpr:
430 case scUMinExpr:
431 case scSMinExpr:
432 case scSequentialUMinExpr:
433 return cast<SCEVNAryExpr>(this)->operands();
434 case scUDivExpr:
435 return cast<SCEVUDivExpr>(this)->operands();
436 case scCouldNotCompute:
437 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
438 }
439 llvm_unreachable("Unknown SCEV kind!");
440}
441
442bool SCEV::isZero() const {
443 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
444 return SC->getValue()->isZero();
445 return false;
446}
447
448bool SCEV::isOne() const {
449 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
450 return SC->getValue()->isOne();
451 return false;
452}
453
454bool SCEV::isAllOnesValue() const {
455 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
456 return SC->getValue()->isMinusOne();
457 return false;
458}
459
460bool SCEV::isNonConstantNegative() const {
461 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
462 if (!Mul) return false;
463
464 // If there is a constant factor, it will be first.
465 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
466 if (!SC) return false;
467
468 // Return true if the value is negative, this matches things like (-42 * V).
469 return SC->getAPInt().isNegative();
470}
471
472SCEVCouldNotCompute::SCEVCouldNotCompute() :
473 SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
474
475bool SCEVCouldNotCompute::classof(const SCEV *S) {
476 return S->getSCEVType() == scCouldNotCompute;
477}
478
479const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
480 FoldingSetNodeID ID;
481 ID.AddInteger(scConstant);
482 ID.AddPointer(V);
483 void *IP = nullptr;
484 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
485 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
486 UniqueSCEVs.InsertNode(S, IP);
487 return S;
488}
489
490const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
491 return getConstant(ConstantInt::get(getContext(), Val));
492}
493
494const SCEV *
495ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
496 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
497 return getConstant(ConstantInt::get(ITy, V, isSigned));
498}
499
500const SCEV *ScalarEvolution::getVScale(Type *Ty) {
501 FoldingSetNodeID ID;
502 ID.AddInteger(scVScale);
503 ID.AddPointer(Ty);
504 void *IP = nullptr;
505 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
506 return S;
507 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
508 UniqueSCEVs.InsertNode(S, IP);
509 return S;
510}
511
512const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC) {
513 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
514 if (EC.isScalable())
515 Res = getMulExpr(Res, getVScale(Ty));
516 return Res;
517}
518
519SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
520 const SCEV *op, Type *ty)
521 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
522
523SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
524 Type *ITy)
525 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
526 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
527 "Must be a non-bit-width-changing pointer-to-integer cast!");
528}
529
530SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
531 SCEVTypes SCEVTy, const SCEV *op,
532 Type *ty)
533 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
534
535SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
536 Type *ty)
537 : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
538 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
539 "Cannot truncate non-integer value!");
540}
541
542SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
543 const SCEV *op, Type *ty)
544 : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
545 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
546 "Cannot zero extend non-integer value!");
547}
548
549SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
550 const SCEV *op, Type *ty)
551 : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
552 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
553 "Cannot sign extend non-integer value!");
554}
555
556void SCEVUnknown::deleted() {
557 // Clear this SCEVUnknown from various maps.
558 SE->forgetMemoizedResults(this);
559
560 // Remove this SCEVUnknown from the uniquing map.
561 SE->UniqueSCEVs.RemoveNode(this);
562
563 // Release the value.
564 setValPtr(nullptr);
565}
566
567void SCEVUnknown::allUsesReplacedWith(Value *New) {
568 // Clear this SCEVUnknown from various maps.
569 SE->forgetMemoizedResults(this);
570
571 // Remove this SCEVUnknown from the uniquing map.
572 SE->UniqueSCEVs.RemoveNode(this);
573
574 // Replace the value pointer in case someone is still using this SCEVUnknown.
575 setValPtr(New);
576}
577
578//===----------------------------------------------------------------------===//
579// SCEV Utilities
580//===----------------------------------------------------------------------===//
581
582/// Compare the two values \p LV and \p RV in terms of their "complexity" where
583/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
584/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
585/// have been previously deemed to be "equally complex" by this routine. It is
586/// intended to avoid exponential time complexity in cases like:
587///
588/// %a = f(%x, %y)
589/// %b = f(%a, %a)
590/// %c = f(%b, %b)
591///
592/// %d = f(%x, %y)
593/// %e = f(%d, %d)
594/// %f = f(%e, %e)
595///
596/// CompareValueComplexity(%f, %c)
597///
598/// Since we do not continue running this routine on expression trees once we
599/// have seen unequal values, there is no need to track them in the cache.
600static int
601CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
602 const LoopInfo *const LI, Value *LV, Value *RV,
603 unsigned Depth) {
604 if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
605 return 0;
606
607 // Order pointer values after integer values. This helps SCEVExpander form
608 // GEPs.
609 bool LIsPointer = LV->getType()->isPointerTy(),
610 RIsPointer = RV->getType()->isPointerTy();
611 if (LIsPointer != RIsPointer)
612 return (int)LIsPointer - (int)RIsPointer;
613
614 // Compare getValueID values.
615 unsigned LID = LV->getValueID(), RID = RV->getValueID();
616 if (LID != RID)
617 return (int)LID - (int)RID;
618
619 // Sort arguments by their position.
620 if (const auto *LA = dyn_cast<Argument>(LV)) {
621 const auto *RA = cast<Argument>(RV);
622 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
623 return (int)LArgNo - (int)RArgNo;
624 }
625
626 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
627 const auto *RGV = cast<GlobalValue>(RV);
628
629 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
630 auto LT = GV->getLinkage();
631 return !(GlobalValue::isPrivateLinkage(LT) ||
632 GlobalValue::isInternalLinkage(LT));
633 };
634
635 // Use the names to distinguish the two values, but only if the
636 // names are semantically important.
637 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
638 return LGV->getName().compare(RGV->getName());
639 }
640
641 // For instructions, compare their loop depth, and their operand count. This
642 // is pretty loose.
643 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
644 const auto *RInst = cast<Instruction>(RV);
645
646 // Compare loop depths.
647 const BasicBlock *LParent = LInst->getParent(),
648 *RParent = RInst->getParent();
649 if (LParent != RParent) {
650 unsigned LDepth = LI->getLoopDepth(LParent),
651 RDepth = LI->getLoopDepth(RParent);
652 if (LDepth != RDepth)
653 return (int)LDepth - (int)RDepth;
654 }
655
656 // Compare the number of operands.
657 unsigned LNumOps = LInst->getNumOperands(),
658 RNumOps = RInst->getNumOperands();
659 if (LNumOps != RNumOps)
660 return (int)LNumOps - (int)RNumOps;
661
662 for (unsigned Idx : seq(LNumOps)) {
663 int Result =
664 CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
665 RInst->getOperand(Idx), Depth + 1);
666 if (Result != 0)
667 return Result;
668 }
669 }
670
671 EqCacheValue.unionSets(LV, RV);
672 return 0;
673}
674
675// Return negative, zero, or positive, if LHS is less than, equal to, or greater
676// than RHS, respectively. A three-way result allows recursive comparisons to be
677// more efficient.
678// If the max analysis depth was reached, return std::nullopt, assuming we do
679// not know if they are equivalent for sure.
680static std::optional<int>
681CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
682 EquivalenceClasses<const Value *> &EqCacheValue,
683 const LoopInfo *const LI, const SCEV *LHS,
684 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
685 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
686 if (LHS == RHS)
687 return 0;
688
689 // Primarily, sort the SCEVs by their getSCEVType().
690 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
691 if (LType != RType)
692 return (int)LType - (int)RType;
693
694 if (EqCacheSCEV.isEquivalent(LHS, RHS))
695 return 0;
696
697 if (Depth > MaxSCEVCompareDepth)
698 return std::nullopt;
699
700 // Aside from the getSCEVType() ordering, the particular ordering
701 // isn't very important except that it's beneficial to be consistent,
702 // so that (a + b) and (b + a) don't end up as different expressions.
703 switch (LType) {
704 case scUnknown: {
705 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
706 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
707
708 int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
709 RU->getValue(), Depth + 1);
710 if (X == 0)
711 EqCacheSCEV.unionSets(LHS, RHS);
712 return X;
713 }
714
715 case scConstant: {
716 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
717 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
718
719 // Compare constant values.
720 const APInt &LA = LC->getAPInt();
721 const APInt &RA = RC->getAPInt();
722 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
723 if (LBitWidth != RBitWidth)
724 return (int)LBitWidth - (int)RBitWidth;
725 return LA.ult(RA) ? -1 : 1;
726 }
727
728 case scVScale: {
729 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
730 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
731 return LTy->getBitWidth() - RTy->getBitWidth();
732 }
733
734 case scAddRecExpr: {
735 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
736 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
737
738 // There is always a dominance between two recs that are used by one SCEV,
739 // so we can safely sort recs by loop header dominance. We require such
740 // order in getAddExpr.
741 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
742 if (LLoop != RLoop) {
743 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
744 assert(LHead != RHead && "Two loops share the same header?");
745 if (DT.dominates(LHead, RHead))
746 return 1;
747 assert(DT.dominates(RHead, LHead) &&
748 "No dominance between recurrences used by one SCEV?");
749 return -1;
750 }
751
752 [[fallthrough]];
753 }
754
755 case scTruncate:
756 case scZeroExtend:
757 case scSignExtend:
758 case scPtrToInt:
759 case scAddExpr:
760 case scMulExpr:
761 case scUDivExpr:
762 case scSMaxExpr:
763 case scUMaxExpr:
764 case scSMinExpr:
765 case scUMinExpr:
766 case scSequentialUMinExpr: {
767 ArrayRef<const SCEV *> LOps = LHS->operands();
768 ArrayRef<const SCEV *> ROps = RHS->operands();
769
770 // Lexicographically compare n-ary-like expressions.
771 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
772 if (LNumOps != RNumOps)
773 return (int)LNumOps - (int)RNumOps;
774
775 for (unsigned i = 0; i != LNumOps; ++i) {
776 auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LOps[i],
777 ROps[i], DT, Depth + 1);
778 if (X != 0)
779 return X;
780 }
781 EqCacheSCEV.unionSets(LHS, RHS);
782 return 0;
783 }
784
785 case scCouldNotCompute:
786 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
787 }
788 llvm_unreachable("Unknown SCEV kind!");
789}
790
791/// Given a list of SCEV objects, order them by their complexity, and group
792/// objects of the same complexity together by value. When this routine is
793/// finished, we know that any duplicates in the vector are consecutive and that
794/// complexity is monotonically increasing.
795///
796/// Note that we take special precautions to ensure that we get deterministic
797/// results from this routine. In other words, we don't want the results of
798/// this to depend on where the addresses of various SCEV objects happened to
799/// land in memory.
800static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
801 LoopInfo *LI, DominatorTree &DT) {
802 if (Ops.size() < 2) return; // Noop
803
804 EquivalenceClasses<const SCEV *> EqCacheSCEV;
805 EquivalenceClasses<const Value *> EqCacheValue;
806
807 // Whether LHS has provably less complexity than RHS.
808 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
809 auto Complexity =
810 CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
811 return Complexity && *Complexity < 0;
812 };
813 if (Ops.size() == 2) {
814 // This is the common case, which also happens to be trivially simple.
815 // Special case it.
816 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
817 if (IsLessComplex(RHS, LHS))
818 std::swap(LHS, RHS);
819 return;
820 }
821
822 // Do the rough sort by complexity.
823 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
824 return IsLessComplex(LHS, RHS);
825 });
826
827 // Now that we are sorted by complexity, group elements of the same
828 // complexity. Note that this is, at worst, N^2, but the vector is likely to
829 // be extremely short in practice. Note that we take this approach because we
830 // do not want to depend on the addresses of the objects we are grouping.
831 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
832 const SCEV *S = Ops[i];
833 unsigned Complexity = S->getSCEVType();
834
835 // If there are any objects of the same complexity and same value as this
836 // one, group them.
837 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
838 if (Ops[j] == S) { // Found a duplicate.
839 // Move it to immediately after i'th element.
840 std::swap(Ops[i+1], Ops[j]);
841 ++i; // no need to rescan it.
842 if (i == e-2) return; // Done!
843 }
844 }
845 }
846}
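
// Illustrative example (not from the upstream sources): for an add whose
// operand list arrives as (%a, 4, %b, %a), grouping by complexity yields
// (4, %a, %a, %b): the constant sorts first and the duplicate %a's become
// adjacent, which lets getAddExpr fold them into (2 * %a) with a single scan.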
847
848/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
849/// least HugeExprThreshold nodes).
850static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
851 return any_of(Ops, [](const SCEV *S) {
852 return S->getExpressionSize() >= HugeExprThreshold;
853 });
854}
855
856//===----------------------------------------------------------------------===//
857// Simple SCEV method implementations
858//===----------------------------------------------------------------------===//
859
860/// Compute BC(It, K). The result has width W. Assume K > 0.
861static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
862 ScalarEvolution &SE,
863 Type *ResultTy) {
864 // Handle the simplest case efficiently.
865 if (K == 1)
866 return SE.getTruncateOrZeroExtend(It, ResultTy);
867
868 // We are using the following formula for BC(It, K):
869 //
870 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
871 //
872 // Suppose W is the bitwidth of the return value. We must be prepared for
873 // overflow. Hence, we must ensure that the result of our computation is
874 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
875 // safe in modular arithmetic.
876 //
877 // However, this code doesn't use exactly that formula; the formula it uses
878 // is something like the following, where T is the number of factors of 2 in
879 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
880 // exponentiation:
881 //
882 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
883 //
884 // This formula is trivially equivalent to the previous formula. However,
885 // this formula can be implemented much more efficiently. The trick is that
886 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
887 // arithmetic. To do exact division in modular arithmetic, all we have
888 // to do is multiply by the inverse. Therefore, this step can be done at
889 // width W.
890 //
891 // The next issue is how to safely do the division by 2^T. The way this
892 // is done is by doing the multiplication step at a width of at least W + T
893 // bits. This way, the bottom W+T bits of the product are accurate. Then,
894 // when we perform the division by 2^T (which is equivalent to a right shift
895 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
896 // truncated out after the division by 2^T.
897 //
898 // In comparison to just directly using the first formula, this technique
899 // is much more efficient; using the first formula requires W * K bits,
900 // but this formula requires less than W + K bits. Also, the first formula requires
901 // a division step, whereas this formula only requires multiplies and shifts.
902 //
903 // It doesn't matter whether the subtraction step is done in the calculation
904 // width or the input iteration count's width; if the subtraction overflows,
905 // the result must be zero anyway. We prefer here to do it in the width of
906 // the induction variable because it helps a lot for certain cases; CodeGen
907 // isn't smart enough to ignore the overflow, which leads to much less
908 // efficient code if the width of the subtraction is wider than the native
909 // register width.
910 //
911 // (It's possible to not widen at all by pulling out factors of 2 before
912 // the multiplication; for example, K=2 can be calculated as
913 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
914 // extra arithmetic, so it's not an obvious win, and it gets
915 // much more complicated for K > 3.)
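  // Worked example (illustrative, values chosen for exposition): for K = 3 and
  // W = 32, K! = 6 = 2^1 * 3, so T = 1 and the odd factor is 3.  The product
  // It*(It-1)*(It-2) is formed at W + T = 33 bits, divided by 2^1, truncated
  // back to 32 bits, and multiplied by the multiplicative inverse of 3 modulo
  // 2^32 (0xAAAAAAAB), which performs the exact division by 3.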
916
917 // Protection from insane SCEVs; this bound is conservative,
918 // but it probably doesn't matter.
919 if (K > 1000)
920 return SE.getCouldNotCompute();
921
922 unsigned W = SE.getTypeSizeInBits(ResultTy);
923
924 // Calculate K! / 2^T and T; we divide out the factors of two before
925 // multiplying for calculating K! / 2^T to avoid overflow.
926 // Other overflow doesn't matter because we only care about the bottom
927 // W bits of the result.
928 APInt OddFactorial(W, 1);
929 unsigned T = 1;
930 for (unsigned i = 3; i <= K; ++i) {
931 unsigned TwoFactors = countr_zero(i);
932 T += TwoFactors;
933 OddFactorial *= (i >> TwoFactors);
934 }
935
936 // We need at least W + T bits for the multiplication step
937 unsigned CalculationBits = W + T;
938
939 // Calculate 2^T, at width T+W.
940 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
941
942 // Calculate the multiplicative inverse of K! / 2^T;
943 // this multiplication factor will perform the exact division by
944 // K! / 2^T.
945 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
946
947 // Calculate the product, at width T+W
948 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
949 CalculationBits);
950 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
951 for (unsigned i = 1; i != K; ++i) {
952 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
953 Dividend = SE.getMulExpr(Dividend,
954 SE.getTruncateOrZeroExtend(S, CalculationTy));
955 }
956
957 // Divide by 2^T
958 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
959
960 // Truncate the result, and divide by K! / 2^T.
961
962 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
963 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
964}
965
966/// Return the value of this chain of recurrences at the specified iteration
967/// number. We can evaluate this recurrence by multiplying each element in the
968/// chain by the binomial coefficient corresponding to it. In other words, we
969/// can evaluate {A,+,B,+,C,+,D} as:
970///
971/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
972///
973/// where BC(It, k) stands for binomial coefficient.
974const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
975 ScalarEvolution &SE) const {
976 return evaluateAtIteration(operands(), It, SE);
977}
978
979const SCEV *
980SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
981 const SCEV *It, ScalarEvolution &SE) {
982 assert(Operands.size() > 0);
983 const SCEV *Result = Operands[0];
984 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
985 // The computation is correct in the face of overflow provided that the
986 // multiplication is performed _after_ the evaluation of the binomial
987 // coefficient.
988 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
989 if (isa<SCEVCouldNotCompute>(Coeff))
990 return Coeff;
991
992 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
993 }
994 return Result;
995}
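
// Illustrative example (values chosen for exposition): evaluating the
// recurrence {5,+,3,+,2} at iteration It gives
//   5*BC(It,0) + 3*BC(It,1) + 2*BC(It,2) = 5 + 3*It + It*(It-1),
// so at It = 4 the value is 5 + 12 + 12 = 29.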
996
997//===----------------------------------------------------------------------===//
998// SCEV Expression folder implementations
999//===----------------------------------------------------------------------===//
1000
1001const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1002 unsigned Depth) {
1003 assert(Depth <= 1 &&
1004 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1005
1006 // We could be called with an integer-typed operands during SCEV rewrites.
1007 // Since the operand is an integer already, just perform zext/trunc/self cast.
1008 if (!Op->getType()->isPointerTy())
1009 return Op;
1010
1011 // What would be an ID for such a SCEV cast expression?
1012 FoldingSetNodeID ID;
1013 ID.AddInteger(scPtrToInt);
1014 ID.AddPointer(Op);
1015
1016 void *IP = nullptr;
1017
1018 // Is there already an expression for such a cast?
1019 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1020 return S;
1021
1022 // It isn't legal for optimizations to construct new ptrtoint expressions
1023 // for non-integral pointers.
1024 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1025 return getCouldNotCompute();
1026
1027 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1028
1029 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1030 // is sufficiently wide to represent all possible pointer values.
1031 // We could theoretically teach SCEV to truncate wider pointers, but
1032 // that isn't implemented for now.
1033 if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1034 getDataLayout().getTypeSizeInBits(IntPtrTy))
1035 return getCouldNotCompute();
1036
1037 // If not, is this expression something we can't reduce any further?
1038 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1039 // Perform some basic constant folding. If the operand of the ptr2int cast
1040 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1041 // left as-is), but produce a zero constant.
1042 // NOTE: We could handle a more general case, but lack motivational cases.
1043 if (isa<ConstantPointerNull>(U->getValue()))
1044 return getZero(IntPtrTy);
1045
1046 // Create an explicit cast node.
1047 // We can reuse the existing insert position since if we get here,
1048 // we won't have made any changes which would invalidate it.
1049 SCEV *S = new (SCEVAllocator)
1050 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1051 UniqueSCEVs.InsertNode(S, IP);
1052 registerUser(S, Op);
1053 return S;
1054 }
1055
1056 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1057 "non-SCEVUnknown's.");
1058
1059 // Otherwise, we've got some expression that is more complex than just a
1060 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1061 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1062 // only, and the expressions must otherwise be integer-typed.
1063 // So sink the cast down to the SCEVUnknown's.
1064
1065 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1066 /// which computes a pointer-typed value, and rewrites the whole expression
1067 /// tree so that *all* the computations are done on integers, and the only
1068 /// pointer-typed operands in the expression are SCEVUnknown.
1069 class SCEVPtrToIntSinkingRewriter
1070 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1071 using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1072
1073 public:
1074 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1075
1076 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1077 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1078 return Rewriter.visit(Scev);
1079 }
1080
1081 const SCEV *visit(const SCEV *S) {
1082 Type *STy = S->getType();
1083 // If the expression is not pointer-typed, just keep it as-is.
1084 if (!STy->isPointerTy())
1085 return S;
1086 // Else, recursively sink the cast down into it.
1087 return Base::visit(S);
1088 }
1089
1090 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1091 SmallVector<const SCEV *, 2> Operands;
1092 bool Changed = false;
1093 for (const auto *Op : Expr->operands()) {
1094 Operands.push_back(visit(Op));
1095 Changed |= Op != Operands.back();
1096 }
1097 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1098 }
1099
1100 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1101 SmallVector<const SCEV *, 2> Operands;
1102 bool Changed = false;
1103 for (const auto *Op : Expr->operands()) {
1104 Operands.push_back(visit(Op));
1105 Changed |= Op != Operands.back();
1106 }
1107 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1108 }
1109
1110 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1111 assert(Expr->getType()->isPointerTy() &&
1112 "Should only reach pointer-typed SCEVUnknown's.");
1113 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1114 }
1115 };
1116
1117 // And actually perform the cast sinking.
1118 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1119 assert(IntOp->getType()->isIntegerTy() &&
1120 "We must have succeeded in sinking the cast, "
1121 "and ending up with an integer-typed expression!");
1122 return IntOp;
1123}
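
// Illustrative example (not from the upstream comments): for a pointer-typed
// SCEV such as (16 + %base), the rewriter above sinks the cast down to the
// SCEVUnknown leaf, producing (16 + (ptrtoint %base)) rather than a ptrtoint
// of the whole add expression.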
1124
1125const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1126 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1127
1128 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1129 if (isa<SCEVCouldNotCompute>(IntOp))
1130 return IntOp;
1131
1132 return getTruncateOrZeroExtend(IntOp, Ty);
1133}
1134
1135const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1136 unsigned Depth) {
1137 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1138 "This is not a truncating conversion!");
1139 assert(isSCEVable(Ty) &&
1140 "This is not a conversion to a SCEVable type!");
1141 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1142 Ty = getEffectiveSCEVType(Ty);
1143
1144 FoldingSetNodeID ID;
1145 ID.AddInteger(scTruncate);
1146 ID.AddPointer(Op);
1147 ID.AddPointer(Ty);
1148 void *IP = nullptr;
1149 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1150
1151 // Fold if the operand is constant.
1152 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1153 return getConstant(
1154 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1155
1156 // trunc(trunc(x)) --> trunc(x)
1157 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1158 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1159
1160 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1161 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1162 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1163
1164 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1165 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1166 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1167
1168 if (Depth > MaxCastDepth) {
1169 SCEV *S =
1170 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1171 UniqueSCEVs.InsertNode(S, IP);
1172 registerUser(S, Op);
1173 return S;
1174 }
1175
1176 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1177 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1178 // if after transforming we have at most one truncate, not counting truncates
1179 // that replace other casts.
1180 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1181 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1182 SmallVector<const SCEV *, 4> Operands;
1183 unsigned numTruncs = 0;
1184 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1185 ++i) {
1186 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1187 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1188 isa<SCEVTruncateExpr>(S))
1189 numTruncs++;
1190 Operands.push_back(S);
1191 }
1192 if (numTruncs < 2) {
1193 if (isa<SCEVAddExpr>(Op))
1194 return getAddExpr(Operands);
1195 if (isa<SCEVMulExpr>(Op))
1196 return getMulExpr(Operands);
1197 llvm_unreachable("Unexpected SCEV type for Op.");
1198 }
1199 // Although we checked in the beginning that ID is not in the cache, it is
1200 // possible that during recursion and subsequent modifications ID was inserted
1201 // into the cache. So if we find it, just return it.
1202 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1203 return S;
1204 }
1205
1206 // If the input value is a chrec scev, truncate the chrec's operands.
1207 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1208 SmallVector<const SCEV *, 4> Operands;
1209 for (const SCEV *Op : AddRec->operands())
1210 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1211 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1212 }
1213
1214 // Return zero if truncating to known zeros.
1215 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1216 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1217 return getZero(Ty);
1218
1219 // The cast wasn't folded; create an explicit cast node. We can reuse
1220 // the existing insert position since if we get here, we won't have
1221 // made any changes which would invalidate it.
1222 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1223 Op, Ty);
1224 UniqueSCEVs.InsertNode(S, IP);
1225 registerUser(S, Op);
1226 return S;
1227}
1228
1229// Get the limit of a recurrence such that incrementing by Step cannot cause
1230// signed overflow as long as the value of the recurrence within the
1231// loop does not exceed this limit before incrementing.
1232static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1233 ICmpInst::Predicate *Pred,
1234 ScalarEvolution *SE) {
1235 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1236 if (SE->isKnownPositive(Step)) {
1237 *Pred = ICmpInst::ICMP_SLT;
1238 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1239 SE->getSignedRangeMax(Step));
1240 }
1241 if (SE->isKnownNegative(Step)) {
1242 *Pred = ICmpInst::ICMP_SGT;
1243 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1244 SE->getSignedRangeMin(Step));
1245 }
1246 return nullptr;
1247}
1248
1249// Get the limit of a recurrence such that incrementing by Step cannot cause
1250// unsigned overflow as long as the value of the recurrence within the loop does
1251// not exceed this limit before incrementing.
1252static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1253 ICmpInst::Predicate *Pred,
1254 ScalarEvolution *SE) {
1255 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1256 *Pred = ICmpInst::ICMP_ULT;
1257
1258 return SE->getConstant(APInt::getMaxValue(BitWidth) -
1259 SE->getUnsignedRangeMax(Step));
1260}
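
// Illustrative example (not part of the upstream comment): for an i8
// recurrence whose Step is known to be in [1, 1], the unsigned limit above is
// 255 - 1 = 254 with predicate ULT: as long as the recurrence value is
// unsigned-less-than 254 before an increment, adding the step cannot wrap
// around 2^8.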
1261
1262namespace {
1263
1264struct ExtendOpTraitsBase {
1265 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1266 unsigned);
1267};
1268
1269// Used to make code generic over signed and unsigned overflow.
1270template <typename ExtendOp> struct ExtendOpTraits {
1271 // Members present:
1272 //
1273 // static const SCEV::NoWrapFlags WrapType;
1274 //
1275 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1276 //
1277 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1278 // ICmpInst::Predicate *Pred,
1279 // ScalarEvolution *SE);
1280};
1281
1282template <>
1283struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1284 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1285
1286 static const GetExtendExprTy GetExtendExpr;
1287
1288 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1289 ICmpInst::Predicate *Pred,
1290 ScalarEvolution *SE) {
1291 return getSignedOverflowLimitForStep(Step, Pred, SE);
1292 }
1293};
1294
1295const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1296 SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1297
1298template <>
1299struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1300 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1301
1302 static const GetExtendExprTy GetExtendExpr;
1303
1304 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1305 ICmpInst::Predicate *Pred,
1306 ScalarEvolution *SE) {
1307 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1308 }
1309};
1310
1311const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1312 SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1313
1314} // end anonymous namespace
1315
1316// The recurrence AR has been shown to have no signed/unsigned wrap or something
1317// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1318// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1319// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1320// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1321// expression "Step + sext/zext(PreIncAR)" is congruent with
1322// "sext/zext(PostIncAR)"
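//
// For instance (illustrative): if {%start + 4,+,4} is known <nuw>, then
// zext({%start + 4,+,4}) can be rewritten so the extension applies to the
// pre-increment start: its start becomes (4 + zext(%start)) rather than
// zext(%start + 4), exposing the step outside the cast.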
1323template <typename ExtendOpTy>
1324static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1325 ScalarEvolution *SE, unsigned Depth) {
1326 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1327 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1328
1329 const Loop *L = AR->getLoop();
1330 const SCEV *Start = AR->getStart();
1331 const SCEV *Step = AR->getStepRecurrence(*SE);
1332
1333 // Check for a simple looking step prior to loop entry.
1334 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1335 if (!SA)
1336 return nullptr;
1337
1338 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1339 // subtraction is expensive. For this purpose, perform a quick and dirty
1340 // difference, by checking for Step in the operand list. Note, that
1341 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1342 SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1343 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1344 if (*It == Step) {
1345 DiffOps.erase(It);
1346 break;
1347 }
1348
1349 if (DiffOps.size() == SA->getNumOperands())
1350 return nullptr;
1351
1352 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1353 // `Step`:
1354
1355 // 1. NSW/NUW flags on the step increment.
1356 auto PreStartFlags =
1357 ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1358 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1359 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1360 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1361
1362 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1363 // "S+X does not sign/unsign-overflow".
1364 //
1365
1366 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1367 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1368 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1369 return PreStart;
1370
1371 // 2. Direct overflow check on the step operation's expression.
1372 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1373 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1374 const SCEV *OperandExtendedStart =
1375 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1376 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1377 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1378 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1379 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1380 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1381 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1382 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1383 }
1384 return PreStart;
1385 }
1386
1387 // 3. Loop precondition.
1388 ICmpInst::Predicate Pred;
1389 const SCEV *OverflowLimit =
1390 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1391
1392 if (OverflowLimit &&
1393 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1394 return PreStart;
1395
1396 return nullptr;
1397}
1398
1399// Get the normalized zero or sign extended expression for this AddRec's Start.
1400template <typename ExtendOpTy>
1401static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1402 ScalarEvolution *SE,
1403 unsigned Depth) {
1404 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1405
1406 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1407 if (!PreStart)
1408 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1409
1410 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1411 Depth),
1412 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1413}
1414
1415// Try to prove away overflow by looking at "nearby" add recurrences. A
1416// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1417// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1418//
1419// Formally:
1420//
1421// {S,+,X} == {S-T,+,X} + T
1422// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1423//
1424// If ({S-T,+,X} + T) does not overflow ... (1)
1425//
1426// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1427//
1428// If {S-T,+,X} does not overflow ... (2)
1429//
1430// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1431// == {Ext(S-T)+Ext(T),+,Ext(X)}
1432//
1433// If (S-T)+T does not overflow ... (3)
1434//
1435// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1436// == {Ext(S),+,Ext(X)} == LHS
1437//
1438// Thus, if (1), (2) and (3) are true for some T, then
1439// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1440//
1441// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1442// does not overflow" restricted to the 0th iteration. Therefore we only need
1443// to check for (1) and (2).
1444//
1445// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1446// is `Delta` (defined below).
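//
// Tying this back to the motivating example (illustrative): with Ext = zext,
// S = 1, X = 4 and T = 1, knowing that {0,+,4} is <nuw> gives (2), and knowing
// that {0,+,4} stays ult -1 means adding T = 1 cannot overflow, giving (1);
// together they let us conclude zext({1,+,4}) == {zext(1),+,zext(4)}.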
1447template <typename ExtendOpTy>
1448bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1449 const SCEV *Step,
1450 const Loop *L) {
1451 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1452
1453 // We restrict `Start` to a constant to prevent SCEV from spending too much
1454 // time here. It is correct (but more expensive) to continue with a
1455 // non-constant `Start` and do a general SCEV subtraction to compute
1456 // `PreStart` below.
1457 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1458 if (!StartC)
1459 return false;
1460
1461 APInt StartAI = StartC->getAPInt();
1462
1463 for (unsigned Delta : {-2, -1, 1, 2}) {
1464 const SCEV *PreStart = getConstant(StartAI - Delta);
1465
1466 FoldingSetNodeID ID;
1467 ID.AddInteger(scAddRecExpr);
1468 ID.AddPointer(PreStart);
1469 ID.AddPointer(Step);
1470 ID.AddPointer(L);
1471 void *IP = nullptr;
1472 const auto *PreAR =
1473 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1474
1475 // Give up if we don't already have the add recurrence we need because
1476 // actually constructing an add recurrence is relatively expensive.
1477 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1478 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1479 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1480 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1481 DeltaS, &Pred, this);
1482 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1483 return true;
1484 }
1485 }
1486
1487 return false;
1488}
1489
1490// Finds an integer D for an expression (C + x + y + ...) such that the top
1491// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1492// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1493// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1494// the (C + x + y + ...) expression is \p WholeAddExpr.
1495static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1496 const SCEVConstant *ConstantTerm,
1497 const SCEVAddExpr *WholeAddExpr) {
1498 const APInt &C = ConstantTerm->getAPInt();
1499 const unsigned BitWidth = C.getBitWidth();
1500 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1501 uint32_t TZ = BitWidth;
1502 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1503 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1504 if (TZ) {
1505 // Set D to be as many least significant bits of C as possible while still
1506 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1507 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1508 }
1509 return APInt(BitWidth, 0);
1510}
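
// Illustrative example (not from the upstream comment): for (11 + 4*x + 4*y)
// the non-constant part has at least 2 trailing zero bits, so D = 11 mod 4 = 3
// is returned; the expression can then be treated as 3 + (8 + 4*x + 4*y),
// where the inner sum is a multiple of 4 and adding 3 to it cannot wrap.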
1511
1512// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1513// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1514// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1515// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1516static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1517 const APInt &ConstantStart,
1518 const SCEV *Step) {
1519 const unsigned BitWidth = ConstantStart.getBitWidth();
1520 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1521 if (TZ)
1522 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1523 : ConstantStart;
1524 return APInt(BitWidth, 0);
1525}
1526
1527static void insertFoldCacheEntry(
1528 const ScalarEvolution::FoldID &ID, const SCEV *S,
1529 DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache,
1530 DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>>
1531 &FoldCacheUser) {
1532 auto I = FoldCache.insert({ID, S});
1533 if (!I.second) {
1534 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1535 // entry.
1536 auto &UserIDs = FoldCacheUser[I.first->second];
1537 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1538 for (unsigned I = 0; I != UserIDs.size(); ++I)
1539 if (UserIDs[I] == ID) {
1540 std::swap(UserIDs[I], UserIDs.back());
1541 break;
1542 }
1543 UserIDs.pop_back();
1544 I.first->second = S;
1545 }
1546 auto R = FoldCacheUser.insert({S, {}});
1547 R.first->second.push_back(ID);
1548}
1549
1550const SCEV *
1551ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1552 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1553 "This is not an extending conversion!");
1554 assert(isSCEVable(Ty) &&
1555 "This is not a conversion to a SCEVable type!");
1556 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1557 Ty = getEffectiveSCEVType(Ty);
1558
1559 FoldID ID(scZeroExtend, Op, Ty);
1560 auto Iter = FoldCache.find(ID);
1561 if (Iter != FoldCache.end())
1562 return Iter->second;
1563
1564 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1565 if (!isa<SCEVZeroExtendExpr>(S))
1566 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1567 return S;
1568}
1569
1570const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
1571 unsigned Depth) {
1572 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1573 "This is not an extending conversion!");
1574 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1575 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1576
1577 // Fold if the operand is constant.
1578 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1579 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1580
1581 // zext(zext(x)) --> zext(x)
1582 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1583 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1584
1585 // Before doing any expensive analysis, check to see if we've already
1586 // computed a SCEV for this Op and Ty.
1587 FoldingSetNodeID ID;
1588 ID.AddInteger(scZeroExtend);
1589 ID.AddPointer(Op);
1590 ID.AddPointer(Ty);
1591 void *IP = nullptr;
1592 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1593 if (Depth > MaxCastDepth) {
1594 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1595 Op, Ty);
1596 UniqueSCEVs.InsertNode(S, IP);
1597 registerUser(S, Op);
1598 return S;
1599 }
1600
1601 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1602 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1603 // It's possible the bits taken off by the truncate were all zero bits. If
1604 // so, we should be able to simplify this further.
1605 const SCEV *X = ST->getOperand();
1606 ConstantRange CR = getUnsignedRange(X);
1607 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1608 unsigned NewBits = getTypeSizeInBits(Ty);
1609 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1610 CR.zextOrTrunc(NewBits)))
1611 return getTruncateOrZeroExtend(X, Ty, Depth);
1612 }
1613
1614 // If the input value is a chrec scev, and we can prove that the value
1615 // did not overflow the old, smaller, value, we can zero extend all of the
1616 // operands (often constants). This allows analysis of something like
1617 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1618 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1619 if (AR->isAffine()) {
1620 const SCEV *Start = AR->getStart();
1621 const SCEV *Step = AR->getStepRecurrence(*this);
1622 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1623 const Loop *L = AR->getLoop();
1624
1625 // If we have special knowledge that this addrec won't overflow,
1626 // we don't need to do any further analysis.
1627 if (AR->hasNoUnsignedWrap()) {
1628 Start =
1629 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1630 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1631 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1632 }
1633
1634 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1635 // Note that this serves two purposes: It filters out loops that are
1636 // simply not analyzable, and it covers the case where this code is
1637 // being called from within backedge-taken count analysis, such that
1638 // attempting to ask for the backedge-taken count would likely result
1639 // in infinite recursion. In the latter case, the analysis code will
1640 // cope with a conservative value, and it will take care to purge
1641 // that value once it has finished.
1642 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1643 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1644 // Manually compute the final value for AR, checking for overflow.
1645
1646 // Check whether the backedge-taken count can be losslessly cast to
1647 // the addrec's type. The count is always unsigned.
1648 const SCEV *CastedMaxBECount =
1649 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1650 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1651 CastedMaxBECount, MaxBECount->getType(), Depth);
1652 if (MaxBECount == RecastedMaxBECount) {
1653 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1654 // Check whether Start+Step*MaxBECount has no unsigned overflow.
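// For illustration: with an i8 addrec {0,+,1} and MaxBECount = 99, the
// narrow sum Start + Step*MaxBECount is 99, and recomputing it in i16 from
// the zero-extended operands also gives 99; because the two agree, the
// narrow computation cannot have wrapped, so the addrec is NUW.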
1655 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1656 SCEV::FlagAnyWrap, Depth + 1);
1657 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1658 SCEV::FlagAnyWrap,
1659 Depth + 1),
1660 WideTy, Depth + 1);
1661 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1662 const SCEV *WideMaxBECount =
1663 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1664 const SCEV *OperandExtendedAdd =
1665 getAddExpr(WideStart,
1666 getMulExpr(WideMaxBECount,
1667 getZeroExtendExpr(Step, WideTy, Depth + 1),
1668 SCEV::FlagAnyWrap, Depth + 1),
1669 SCEV::FlagAnyWrap, Depth + 1);
1670 if (ZAdd == OperandExtendedAdd) {
1671 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1672 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1673 // Return the expression with the addrec on the outside.
1674 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1675 Depth + 1);
1676 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1677 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1678 }
1679 // Similar to above, only this time treat the step value as signed.
1680 // This covers loops that count down.
1681 OperandExtendedAdd =
1682 getAddExpr(WideStart,
1683 getMulExpr(WideMaxBECount,
1684 getSignExtendExpr(Step, WideTy, Depth + 1),
1685 SCEV::FlagAnyWrap, Depth + 1),
1686 SCEV::FlagAnyWrap, Depth + 1);
1687 if (ZAdd == OperandExtendedAdd) {
1688 // Cache knowledge of AR NW, which is propagated to this AddRec.
1689 // Negative step causes unsigned wrap, but it still can't self-wrap.
1690 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1691 // Return the expression with the addrec on the outside.
1692 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1693 Depth + 1);
1694 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1695 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1696 }
1697 }
1698 }
1699
1700 // Normally, in the cases we can prove no-overflow via a
1701 // backedge guarding condition, we can also compute a backedge
1702 // taken count for the loop. The exceptions are assumptions and
1703 // guards present in the loop -- SCEV is not great at exploiting
1704 // these to compute max backedge taken counts, but can still use
1705 // these to prove lack of overflow. Use this fact to avoid
1706 // doing extra work that may not pay off.
1707 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1708 !AC.assumptions().empty()) {
1709
1710 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1711 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1712 if (AR->hasNoUnsignedWrap()) {
1713 // Same as nuw case above - duplicated here to avoid a compile time
1714 // issue. It's not clear whether the order of checks matters, but it
1715 // is one of the two possible causes of a change which was
1716 // reverted. Be conservative for the moment.
1717 Start =
1718 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1719 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1720 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1721 }
1722
1723 // For a negative step, we can extend the operands iff doing so only
1724 // traverses values in the range zext([0,UINT_MAX]).
1725 if (isKnownNegative(Step)) {
1726 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1727 getSignedRangeMin(Step));
1728 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1729 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1730 // Cache knowledge of AR NW, which is propagated to this
1731 // AddRec. Negative step causes unsigned wrap, but it
1732 // still can't self-wrap.
1733 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1734 // Return the expression with the addrec on the outside.
1735 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1736 Depth + 1);
1737 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1738 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1739 }
1740 }
1741 }
1742
1743 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1744 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1745 // where D maximizes the number of trailing zeros of (C - D + Step * n)
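// For instance, with C = 7 and Step = 4 the split is D = 3: the residual
// {4,+,4} only takes values divisible by 4, so when the no-wrap side
// condition holds, zext({7,+,4}) becomes 3 + zext({4,+,4}).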
1746 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1747 const APInt &C = SC->getAPInt();
1748 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1749 if (D != 0) {
1750 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1751 const SCEV *SResidual =
1752 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1753 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1754 return getAddExpr(SZExtD, SZExtR,
1755 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1756 Depth + 1);
1757 }
1758 }
1759
1760 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1761 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1762 Start =
1763 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1764 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1765 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1766 }
1767 }
1768
1769 // zext(A % B) --> zext(A) % zext(B)
1770 {
1771 const SCEV *LHS;
1772 const SCEV *RHS;
1773 if (matchURem(Op, LHS, RHS))
1774 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1775 getZeroExtendExpr(RHS, Ty, Depth + 1));
1776 }
1777
1778 // zext(A / B) --> zext(A) / zext(B).
1779 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1780 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1781 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1782
1783 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1784 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1785 if (SA->hasNoUnsignedWrap()) {
1786 // If the addition does not unsign overflow then we can, by definition,
1787 // commute the zero extension with the addition operation.
1788 SmallVector<const SCEV *, 4> Ops;
1789 for (const auto *Op : SA->operands())
1790 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1791 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1792 }
1793
1794 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1795 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1796 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1797 //
1798 // Address arithmetic often contains expressions like
1799 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1800 // This transformation is useful while proving that such expressions are
1801 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1802 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1803 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1804 if (D != 0) {
1805 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1806 const SCEV *SResidual =
1807 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1808 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1809 return getAddExpr(SZExtD, SZExtR,
1810 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1811 Depth + 1);
1812 }
1813 }
1814 }
1815
1816 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1817 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1818 if (SM->hasNoUnsignedWrap()) {
1819 // If the multiply does not unsign overflow then we can, by definition,
1820 // commute the zero extension with the multiply operation.
1821 SmallVector<const SCEV *, 4> Ops;
1822 for (const auto *Op : SM->operands())
1823 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1824 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1825 }
1826
1827 // zext(2^K * (trunc X to iN)) to iM ->
1828 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1829 //
1830 // Proof:
1831 //
1832 // zext(2^K * (trunc X to iN)) to iM
1833 // = zext((trunc X to iN) << K) to iM
1834 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1835 // (because shl removes the top K bits)
1836 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1837 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1838 //
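// A concrete instance of the proof: with K = 2, N = 8, M = 32,
// zext(4 * (trunc X to i8)) to i32 becomes
// 4 * (zext (trunc X to i6) to i32), since shifting the low 6 bits left
// by 2 cannot overflow 8 bits.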
1839 if (SM->getNumOperands() == 2)
1840 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1841 if (MulLHS->getAPInt().isPowerOf2())
1842 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1843 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1844 MulLHS->getAPInt().logBase2();
1845 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1846 return getMulExpr(
1847 getZeroExtendExpr(MulLHS, Ty),
1848 getZeroExtendExpr(
1849 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1850 SCEV::FlagNUW, Depth + 1);
1851 }
1852 }
1853
1854 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1855 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1856 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1857 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1858 SmallVector<const SCEV *, 4> Operands;
1859 for (auto *Operand : MinMax->operands())
1860 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1861 if (isa<SCEVUMinExpr>(MinMax))
1862 return getUMinExpr(Operands);
1863 return getUMaxExpr(Operands);
1864 }
1865
1866 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1867 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1868 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1869 SmallVector<const SCEV *, 4> Operands;
1870 for (auto *Operand : MinMax->operands())
1871 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1872 return getUMinExpr(Operands, /*Sequential*/ true);
1873 }
1874
1875 // The cast wasn't folded; create an explicit cast node.
1876 // Recompute the insert position, as it may have been invalidated.
1877 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1878 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1879 Op, Ty);
1880 UniqueSCEVs.InsertNode(S, IP);
1881 registerUser(S, Op);
1882 return S;
1883}
1884
1885const SCEV *
1886ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1887 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1888 "This is not an extending conversion!");
1889 assert(isSCEVable(Ty) &&
1890 "This is not a conversion to a SCEVable type!");
1891 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1892 Ty = getEffectiveSCEVType(Ty);
1893
1894 FoldID ID(scSignExtend, Op, Ty);
1895 auto Iter = FoldCache.find(ID);
1896 if (Iter != FoldCache.end())
1897 return Iter->second;
1898
1899 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1900 if (!isa<SCEVSignExtendExpr>(S))
1901 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1902 return S;
1903}
1904
1905const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1906 unsigned Depth) {
1907 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1908 "This is not an extending conversion!");
1909 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1910 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1911 Ty = getEffectiveSCEVType(Ty);
1912
1913 // Fold if the operand is constant.
1914 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1915 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1916
1917 // sext(sext(x)) --> sext(x)
1918 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1919 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1920
1921 // sext(zext(x)) --> zext(x)
1922 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1923 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1924
1925 // Before doing any expensive analysis, check to see if we've already
1926 // computed a SCEV for this Op and Ty.
1927 FoldingSetNodeID ID;
1928 ID.AddInteger(scSignExtend);
1929 ID.AddPointer(Op);
1930 ID.AddPointer(Ty);
1931 void *IP = nullptr;
1932 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1933 // Limit recursion depth.
1934 if (Depth > MaxCastDepth) {
1935 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1936 Op, Ty);
1937 UniqueSCEVs.InsertNode(S, IP);
1938 registerUser(S, Op);
1939 return S;
1940 }
1941
1942 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1943 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1944 // It's possible the bits taken off by the truncate were all sign bits. If
1945 // so, we should be able to simplify this further.
1946 const SCEV *X = ST->getOperand();
1947 ConstantRange CR = getSignedRange(X);
1948 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1949 unsigned NewBits = getTypeSizeInBits(Ty);
1950 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1951 CR.sextOrTrunc(NewBits)))
1952 return getTruncateOrSignExtend(X, Ty, Depth);
1953 }
1954
1955 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1956 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1957 if (SA->hasNoSignedWrap()) {
1958 // If the addition does not sign overflow then we can, by definition,
1959 // commute the sign extension with the addition operation.
1960 SmallVector<const SCEV *, 4> Ops;
1961 for (const auto *Op : SA->operands())
1962 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1963 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1964 }
1965
1966 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1967 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1968 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1969 //
1970 // For instance, this will bring two seemingly different expressions:
1971 // 1 + sext(5 + 20 * %x + 24 * %y) and
1972 // sext(6 + 20 * %x + 24 * %y)
1973 // to the same form:
1974 // 2 + sext(4 + 20 * %x + 24 * %y)
1975 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1976 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1977 if (D != 0) {
1978 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1979 const SCEV *SResidual =
1980 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1981 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1982 return getAddExpr(SSExtD, SSExtR,
1983 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1984 Depth + 1);
1985 }
1986 }
1987 }
1988 // If the input value is a chrec scev, and we can prove that the value
1989 // did not overflow the old, smaller, value, we can sign extend all of the
1990 // operands (often constants). This allows analysis of something like
1991 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1992 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1993 if (AR->isAffine()) {
1994 const SCEV *Start = AR->getStart();
1995 const SCEV *Step = AR->getStepRecurrence(*this);
1996 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1997 const Loop *L = AR->getLoop();
1998
1999 // If we have special knowledge that this addrec won't overflow,
2000 // we don't need to do any further analysis.
2001 if (AR->hasNoSignedWrap()) {
2002 Start =
2003 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2004 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2005 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2006 }
2007
2008 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2009 // Note that this serves two purposes: It filters out loops that are
2010 // simply not analyzable, and it covers the case where this code is
2011 // being called from within backedge-taken count analysis, such that
2012 // attempting to ask for the backedge-taken count would likely result
2013 // in infinite recursion. In the latter case, the analysis code will
2014 // cope with a conservative value, and it will take care to purge
2015 // that value once it has finished.
2016 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2017 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2018 // Manually compute the final value for AR, checking for
2019 // overflow.
2020
2021 // Check whether the backedge-taken count can be losslessly cast to
2022 // the addrec's type. The count is always unsigned.
2023 const SCEV *CastedMaxBECount =
2024 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2025 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2026 CastedMaxBECount, MaxBECount->getType(), Depth);
2027 if (MaxBECount == RecastedMaxBECount) {
2028 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2029 // Check whether Start+Step*MaxBECount has no signed overflow.
2030 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2031 SCEV::FlagAnyWrap, Depth + 1);
2032 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2033 SCEV::FlagAnyWrap,
2034 Depth + 1),
2035 WideTy, Depth + 1);
2036 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2037 const SCEV *WideMaxBECount =
2038 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2039 const SCEV *OperandExtendedAdd =
2040 getAddExpr(WideStart,
2041 getMulExpr(WideMaxBECount,
2042 getSignExtendExpr(Step, WideTy, Depth + 1),
2043 SCEV::FlagAnyWrap, Depth + 1),
2044 SCEV::FlagAnyWrap, Depth + 1);
2045 if (SAdd == OperandExtendedAdd) {
2046 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2047 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2048 // Return the expression with the addrec on the outside.
2049 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2050 Depth + 1);
2051 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2052 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2053 }
2054 // Similar to above, only this time treat the step value as unsigned.
2055 // This covers loops that count up with an unsigned step.
2056 OperandExtendedAdd =
2057 getAddExpr(WideStart,
2058 getMulExpr(WideMaxBECount,
2059 getZeroExtendExpr(Step, WideTy, Depth + 1),
2060 SCEV::FlagAnyWrap, Depth + 1),
2061 SCEV::FlagAnyWrap, Depth + 1);
2062 if (SAdd == OperandExtendedAdd) {
2063 // If AR wraps around then
2064 //
2065 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2066 // => SAdd != OperandExtendedAdd
2067 //
2068 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2069 // (SAdd == OperandExtendedAdd => AR is NW)
2070
2071 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2072
2073 // Return the expression with the addrec on the outside.
2074 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2075 Depth + 1);
2076 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2077 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2078 }
2079 }
2080 }
2081
2082 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2083 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2084 if (AR->hasNoSignedWrap()) {
2085 // Same as nsw case above - duplicated here to avoid a compile time
2086 // issue. It's not clear whether the order of checks matters, but it
2087 // is one of the two possible causes of a change which was
2088 // reverted. Be conservative for the moment.
2089 Start =
2090 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2091 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2092 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2093 }
2094
2095 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2096 // if D + (C - D + Step * n) could be proven to not signed wrap
2097 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2098 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2099 const APInt &C = SC->getAPInt();
2100 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2101 if (D != 0) {
2102 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2103 const SCEV *SResidual =
2104 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2105 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2106 return getAddExpr(SSExtD, SSExtR,
2107 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2108 Depth + 1);
2109 }
2110 }
2111
2112 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2113 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2114 Start =
2115 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2116 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2117 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2118 }
2119 }
2120
2121 // If the input value is provably positive and we could not simplify
2122 // away the sext build a zext instead.
2123 if (isKnownNonNegative(Op))
2124 return getZeroExtendExpr(Op, Ty, Depth + 1);
2125
2126 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2127 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2128 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2129 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2130 SmallVector<const SCEV *, 4> Operands;
2131 for (auto *Operand : MinMax->operands())
2132 Operands.push_back(getSignExtendExpr(Operand, Ty));
2133 if (isa<SCEVSMinExpr>(MinMax))
2134 return getSMinExpr(Operands);
2135 return getSMaxExpr(Operands);
2136 }
2137
2138 // The cast wasn't folded; create an explicit cast node.
2139 // Recompute the insert position, as it may have been invalidated.
2140 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2141 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2142 Op, Ty);
2143 UniqueSCEVs.InsertNode(S, IP);
2144 registerUser(S, { Op });
2145 return S;
2146}
2147
2148const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2149 Type *Ty) {
2150 switch (Kind) {
2151 case scTruncate:
2152 return getTruncateExpr(Op, Ty);
2153 case scZeroExtend:
2154 return getZeroExtendExpr(Op, Ty);
2155 case scSignExtend:
2156 return getSignExtendExpr(Op, Ty);
2157 case scPtrToInt:
2158 return getPtrToIntExpr(Op, Ty);
2159 default:
2160 llvm_unreachable("Not a SCEV cast expression!");
2161 }
2162}
2163
2164/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2165/// unspecified bits out to the given type.
2166const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2167 Type *Ty) {
2168 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2169 "This is not an extending conversion!");
2170 assert(isSCEVable(Ty) &&
2171 "This is not a conversion to a SCEVable type!");
2172 Ty = getEffectiveSCEVType(Ty);
2173
2174 // Sign-extend negative constants.
2175 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2176 if (SC->getAPInt().isNegative())
2177 return getSignExtendExpr(Op, Ty);
2178
2179 // Peel off a truncate cast.
2180 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2181 const SCEV *NewOp = T->getOperand();
2182 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2183 return getAnyExtendExpr(NewOp, Ty);
2184 return getTruncateOrNoop(NewOp, Ty);
2185 }
2186
2187 // Next try a zext cast. If the cast is folded, use it.
2188 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2189 if (!isa<SCEVZeroExtendExpr>(ZExt))
2190 return ZExt;
2191
2192 // Next try a sext cast. If the cast is folded, use it.
2193 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2194 if (!isa<SCEVSignExtendExpr>(SExt))
2195 return SExt;
2196
2197 // Force the cast to be folded into the operands of an addrec.
2198 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2199 SmallVector<const SCEV *, 4> Ops;
2200 for (const SCEV *Op : AR->operands())
2201 Ops.push_back(getAnyExtendExpr(Op, Ty));
2202 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2203 }
2204
2205 // If the expression is obviously signed, use the sext cast value.
2206 if (isa<SCEVSMaxExpr>(Op))
2207 return SExt;
2208
2209 // Absent any other information, use the zext cast value.
2210 return ZExt;
2211}
2212
2213/// Process the given Ops list, which is a list of operands to be added under
2214/// the given scale, update the given map. This is a helper function for
2215 /// getAddExpr. As an example of what it does, given a sequence of operands
2216/// that would form an add expression like this:
2217///
2218/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2219///
2220/// where A and B are constants, update the map with these values:
2221///
2222/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2223///
2224/// and add 13 + A*B*29 to AccumulatedConstant.
2225 /// This will allow getAddExpr to produce this:
2226///
2227/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2228///
2229/// This form often exposes folding opportunities that are hidden in
2230/// the original operand list.
2231///
2232/// Return true iff it appears that any interesting folding opportunities
2233 /// may be exposed. This helps getAddExpr short-circuit extra work in
2234/// the common case where no interesting opportunities are present, and
2235/// is also used as a check to avoid infinite recursion.
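/// As a concrete check of the example above, take A = 2 and B = 3: the map
/// becomes (m, 7), (n, 1), (o, 2), (p, 2), (q, 6), (r, 0),
/// AccumulatedConstant is 13 + 6*29 = 187, and the rebuilt expression is
/// 187 + n + (m * 7) + ((o + p) * 2) + (q * 6).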
2236static bool
2237CollectAddOperandsWithScales(SmallDenseMap<const SCEV *, APInt, 16> &M,
2238 SmallVectorImpl<const SCEV *> &NewOps,
2239 APInt &AccumulatedConstant,
2240 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2241 ScalarEvolution &SE) {
2242 bool Interesting = false;
2243
2244 // Iterate over the add operands. They are sorted, with constants first.
2245 unsigned i = 0;
2246 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2247 ++i;
2248 // Pull a buried constant out to the outside.
2249 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2250 Interesting = true;
2251 AccumulatedConstant += Scale * C->getAPInt();
2252 }
2253
2254 // Next comes everything else. We're especially interested in multiplies
2255 // here, but they're in the middle, so just visit the rest with one loop.
2256 for (; i != Ops.size(); ++i) {
2257 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2258 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2259 APInt NewScale =
2260 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2261 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2262 // A multiplication of a constant with another add; recurse.
2263 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2264 Interesting |=
2265 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2266 Add->operands(), NewScale, SE);
2267 } else {
2268 // A multiplication of a constant with some other value. Update
2269 // the map.
2270 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2271 const SCEV *Key = SE.getMulExpr(MulOps);
2272 auto Pair = M.insert({Key, NewScale});
2273 if (Pair.second) {
2274 NewOps.push_back(Pair.first->first);
2275 } else {
2276 Pair.first->second += NewScale;
2277 // The map already had an entry for this value, which may indicate
2278 // a folding opportunity.
2279 Interesting = true;
2280 }
2281 }
2282 } else {
2283 // An ordinary operand. Update the map.
2284 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2285 M.insert({Ops[i], Scale});
2286 if (Pair.second) {
2287 NewOps.push_back(Pair.first->first);
2288 } else {
2289 Pair.first->second += Scale;
2290 // The map already had an entry for this value, which may indicate
2291 // a folding opportunity.
2292 Interesting = true;
2293 }
2294 }
2295 }
2296
2297 return Interesting;
2298}
2299
2300bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2301 const SCEV *LHS, const SCEV *RHS,
2302 const Instruction *CtxI) {
2303 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2304 SCEV::NoWrapFlags, unsigned);
2305 switch (BinOp) {
2306 default:
2307 llvm_unreachable("Unsupported binary op");
2308 case Instruction::Add:
2309 Operation = &ScalarEvolution::getAddExpr;
2310 break;
2311 case Instruction::Sub:
2312 Operation = &ScalarEvolution::getMinusSCEV;
2313 break;
2314 case Instruction::Mul:
2315 Operation = &ScalarEvolution::getMulExpr;
2316 break;
2317 }
2318
2319 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2320 Signed ? &ScalarEvolution::getSignExtendExpr
2321 : &ScalarEvolution::getZeroExtendExpr;
2322
2323 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
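// For illustration, for the i8 signed addition 100 + 100: the narrow result
// is -56, which sign-extends to -56 in i16, while sext(100) + sext(100) is
// 200 in i16; the two differ, so signed overflow is (correctly) not ruled
// out. When the two wide values agree, the narrow operation cannot wrap.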
2324 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2325 auto *WideTy =
2326 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2327
2328 const SCEV *A = (this->*Extension)(
2329 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2330 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2331 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2332 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2333 if (A == B)
2334 return true;
2335 // Can we use context to prove the fact we need?
2336 if (!CtxI)
2337 return false;
2338 // TODO: Support mul.
2339 if (BinOp == Instruction::Mul)
2340 return false;
2341 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2342 // TODO: Lift this limitation.
2343 if (!RHSC)
2344 return false;
2345 APInt C = RHSC->getAPInt();
2346 unsigned NumBits = C.getBitWidth();
2347 bool IsSub = (BinOp == Instruction::Sub);
2348 bool IsNegativeConst = (Signed && C.isNegative());
2349 // Compute the direction and magnitude by which we need to check overflow.
2350 bool OverflowDown = IsSub ^ IsNegativeConst;
2351 APInt Magnitude = C;
2352 if (IsNegativeConst) {
2353 if (C == APInt::getSignedMinValue(NumBits))
2354 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2355 // want to deal with that.
2356 return false;
2357 Magnitude = -C;
2358 }
2359
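// For illustration, for an unsigned i8 subtraction of the constant 10
// (OverflowDown is true, Magnitude is 10), the check below reduces to
// proving 10 ule LHS at the context instruction, i.e. that LHS - 10 cannot
// wrap below zero.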
2360 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2361 if (OverflowDown) {
2362 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2363 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2364 : APInt::getMinValue(NumBits);
2365 APInt Limit = Min + Magnitude;
2366 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2367 } else {
2368 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2369 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2370 : APInt::getMaxValue(NumBits);
2371 APInt Limit = Max - Magnitude;
2372 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2373 }
2374}
2375
2376std::optional<SCEV::NoWrapFlags>
2377ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2378 const OverflowingBinaryOperator *OBO) {
2379 // It cannot be done any better.
2380 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2381 return std::nullopt;
2382
2383 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2384
2385 if (OBO->hasNoUnsignedWrap())
2386 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2387 if (OBO->hasNoSignedWrap())
2388 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2389
2390 bool Deduced = false;
2391
2392 if (OBO->getOpcode() != Instruction::Add &&
2393 OBO->getOpcode() != Instruction::Sub &&
2394 OBO->getOpcode() != Instruction::Mul)
2395 return std::nullopt;
2396
2397 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2398 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2399
2400 const Instruction *CtxI =
2401 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2402 if (!OBO->hasNoUnsignedWrap() &&
2403 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2404 /* Signed */ false, LHS, RHS, CtxI)) {
2405 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2406 Deduced = true;
2407 }
2408
2409 if (!OBO->hasNoSignedWrap() &&
2410 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2411 /* Signed */ true, LHS, RHS, CtxI)) {
2412 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2413 Deduced = true;
2414 }
2415
2416 if (Deduced)
2417 return Flags;
2418 return std::nullopt;
2419}
2420
2421// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2422// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2423// can't-overflow flags for the operation if possible.
2424static SCEV::NoWrapFlags
2425StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2426 const ArrayRef<const SCEV *> Ops,
2427 SCEV::NoWrapFlags Flags) {
2428 using namespace std::placeholders;
2429
2430 using OBO = OverflowingBinaryOperator;
2431
2432 bool CanAnalyze =
2433 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2434 (void)CanAnalyze;
2435 assert(CanAnalyze && "don't call from other places!");
2436
2437 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2438 SCEV::NoWrapFlags SignOrUnsignWrap =
2439 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2440
2441 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2442 auto IsKnownNonNegative = [&](const SCEV *S) {
2443 return SE->isKnownNonNegative(S);
2444 };
2445
2446 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2447 Flags =
2448 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2449
2450 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2451
2452 if (SignOrUnsignWrap != SignOrUnsignMask &&
2453 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2454 isa<SCEVConstant>(Ops[0])) {
2455
2456 auto Opcode = [&] {
2457 switch (Type) {
2458 case scAddExpr:
2459 return Instruction::Add;
2460 case scMulExpr:
2461 return Instruction::Mul;
2462 default:
2463 llvm_unreachable("Unexpected SCEV op.");
2464 }
2465 }();
2466
2467 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2468
2469 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
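// For illustration, with C = 10 in i8 the guaranteed-no-signed-wrap region
// for "x + 10" is [-128, 117]; if the signed range of Ops[1] lies inside
// it, the addition can safely be marked <nsw>. The <nuw> case below is
// analogous, using the unsigned range.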
2470 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2471 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2472 Opcode, C, OBO::NoSignedWrap);
2473 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2474 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2475 }
2476
2477 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2478 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2479 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2480 Opcode, C, OBO::NoUnsignedWrap);
2481 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2482 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2483 }
2484 }
2485
2486 // <0,+,nonnegative><nw> is also nuw
2487 // TODO: Add corresponding nsw case
2488 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2489 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2490 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2491 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2492
2493 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2494 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2495 Ops.size() == 2) {
2496 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2497 if (UDiv->getOperand(1) == Ops[1])
2498 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2500 if (UDiv->getOperand(1) == Ops[0])
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2502 }
2503
2504 return Flags;
2505}
2506
2507bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2508 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2509}
2510
2511/// Get a canonical add expression, or something simpler if possible.
2512const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2513 SCEV::NoWrapFlags OrigFlags,
2514 unsigned Depth) {
2515 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2516 "only nuw or nsw allowed");
2517 assert(!Ops.empty() && "Cannot get empty add!");
2518 if (Ops.size() == 1) return Ops[0];
2519#ifndef NDEBUG
2520 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2521 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2522 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2523 "SCEVAddExpr operand types don't match!");
2524 unsigned NumPtrs = count_if(
2525 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2526 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2527#endif
2528
2529 // Sort by complexity, this groups all similar expression types together.
2530 GroupByComplexity(Ops, &LI, DT);
2531
2532 // If there are any constants, fold them together.
2533 unsigned Idx = 0;
2534 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2535 ++Idx;
2536 assert(Idx < Ops.size());
2537 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2538 // We found two constants, fold them together!
2539 Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2540 if (Ops.size() == 2) return Ops[0];
2541 Ops.erase(Ops.begin()+1); // Erase the folded element
2542 LHSC = cast<SCEVConstant>(Ops[0]);
2543 }
2544
2545 // If we are left with a constant zero being added, strip it off.
2546 if (LHSC->getValue()->isZero()) {
2547 Ops.erase(Ops.begin());
2548 --Idx;
2549 }
2550
2551 if (Ops.size() == 1) return Ops[0];
2552 }
2553
2554 // Delay expensive flag strengthening until necessary.
2555 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2556 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2557 };
2558
2559 // Limit recursion calls depth.
2560 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2561 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2562
2563 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2564 // Don't strengthen flags if we have no new information.
2565 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2566 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2567 Add->setNoWrapFlags(ComputeFlags(Ops));
2568 return S;
2569 }
2570
2571 // Okay, check to see if the same value occurs in the operand list more than
2572 // once. If so, merge them together into a multiply expression. Since we
2573 // sorted the list, these values are required to be adjacent.
2574 Type *Ty = Ops[0]->getType();
2575 bool FoundMatch = false;
2576 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2577 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2578 // Scan ahead to count how many equal operands there are.
2579 unsigned Count = 2;
2580 while (i+Count != e && Ops[i+Count] == Ops[i])
2581 ++Count;
2582 // Merge the values into a multiply.
2583 const SCEV *Scale = getConstant(Ty, Count);
2584 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2585 if (Ops.size() == Count)
2586 return Mul;
2587 Ops[i] = Mul;
2588 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2589 --i; e -= Count - 1;
2590 FoundMatch = true;
2591 }
2592 if (FoundMatch)
2593 return getAddExpr(Ops, OrigFlags, Depth + 1);
2594
2595 // Check for truncates. If all the operands are truncated from the same
2596 // type, see if factoring out the truncate would permit the result to be
2597 // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(n*x + m*y)
2598 // if the contents of the resulting outer trunc fold to something simple.
2599 auto FindTruncSrcType = [&]() -> Type * {
2600 // We're ultimately looking to fold an addrec of truncs and muls of only
2601 // constants and truncs, so if we find any other types of SCEV
2602 // as operands of the addrec then we bail and return nullptr here.
2603 // Otherwise, we return the type of the operand of a trunc that we find.
2604 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2605 return T->getOperand()->getType();
2606 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2607 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2608 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2609 return T->getOperand()->getType();
2610 }
2611 return nullptr;
2612 };
2613 if (auto *SrcType = FindTruncSrcType()) {
2614 SmallVector<const SCEV *, 8> LargeOps;
2615 bool Ok = true;
2616 // Check all the operands to see if they can be represented in the
2617 // source type of the truncate.
2618 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2619 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2620 if (T->getOperand()->getType() != SrcType) {
2621 Ok = false;
2622 break;
2623 }
2624 LargeOps.push_back(T->getOperand());
2625 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2626 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2627 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2628 SmallVector<const SCEV *, 8> LargeMulOps;
2629 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2630 if (const SCEVTruncateExpr *T =
2631 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2632 if (T->getOperand()->getType() != SrcType) {
2633 Ok = false;
2634 break;
2635 }
2636 LargeMulOps.push_back(T->getOperand());
2637 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2638 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2639 } else {
2640 Ok = false;
2641 break;
2642 }
2643 }
2644 if (Ok)
2645 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2646 } else {
2647 Ok = false;
2648 break;
2649 }
2650 }
2651 if (Ok) {
2652 // Evaluate the expression in the larger type.
2653 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2654 // If it folds to something simple, use it. Otherwise, don't.
2655 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2656 return getTruncateExpr(Fold, Ty);
2657 }
2658 }
2659
2660 if (Ops.size() == 2) {
2661 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2662 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2663 // C1).
2664 const SCEV *A = Ops[0];
2665 const SCEV *B = Ops[1];
2666 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2667 auto *C = dyn_cast<SCEVConstant>(A);
2668 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2669 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2670 auto C2 = C->getAPInt();
2671 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2672
2673 APInt ConstAdd = C1 + C2;
2674 auto AddFlags = AddExpr->getNoWrapFlags();
2675 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2676 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2677 ConstAdd.ule(C1)) {
2678 PreservedFlags =
2679 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2680 }
2681
2682 // Adding a constant with the same sign and small magnitude is NSW, if the
2683 // original AddExpr was NSW.
2684 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2685 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2686 ConstAdd.abs().ule(C1.abs())) {
2687 PreservedFlags =
2688 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2689 }
2690
2691 if (PreservedFlags != SCEV::FlagAnyWrap) {
2692 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2693 NewOps[0] = getConstant(ConstAdd);
2694 return getAddExpr(NewOps, PreservedFlags);
2695 }
2696 }
2697 }
2698
2699 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
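// This follows from the division identity X == Y * (X /u Y) + (X urem Y),
// so X - (X urem Y) is exactly Y * (X /u Y).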
2700 if (Ops.size() == 2) {
2701 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2702 if (Mul && Mul->getNumOperands() == 2 &&
2703 Mul->getOperand(0)->isAllOnesValue()) {
2704 const SCEV *X;
2705 const SCEV *Y;
2706 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2707 return getMulExpr(Y, getUDivExpr(X, Y));
2708 }
2709 }
2710 }
2711
2712 // Skip past any other cast SCEVs.
2713 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2714 ++Idx;
2715
2716 // If there are add operands they would be next.
2717 if (Idx < Ops.size()) {
2718 bool DeletedAdd = false;
2719 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2720 // common NUW flag for expression after inlining. Other flags cannot be
2721 // preserved, because they may depend on the original order of operations.
2722 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2723 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2724 if (Ops.size() > AddOpsInlineThreshold ||
2725 Add->getNumOperands() > AddOpsInlineThreshold)
2726 break;
2727 // If we have an add, expand the add operands onto the end of the operands
2728 // list.
2729 Ops.erase(Ops.begin()+Idx);
2730 append_range(Ops, Add->operands());
2731 DeletedAdd = true;
2732 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2733 }
2734
2735 // If we deleted at least one add, we added operands to the end of the list,
2736 // and they are not necessarily sorted. Recurse to resort and resimplify
2737 // any operands we just acquired.
2738 if (DeletedAdd)
2739 return getAddExpr(Ops, CommonFlags, Depth + 1);
2740 }
2741
2742 // Skip over the add expression until we get to a multiply.
2743 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2744 ++Idx;
2745
2746 // Check to see if there are any folding opportunities present with
2747 // operands multiplied by constant values.
2748 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2749 uint64_t BitWidth = getTypeSizeInBits(Ty);
2750 SmallDenseMap<const SCEV *, APInt, 16> M;
2751 SmallVector<const SCEV *, 8> NewOps;
2752 APInt AccumulatedConstant(BitWidth, 0);
2753 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2754 Ops, APInt(BitWidth, 1), *this)) {
2755 struct APIntCompare {
2756 bool operator()(const APInt &LHS, const APInt &RHS) const {
2757 return LHS.ult(RHS);
2758 }
2759 };
2760
2761 // Some interesting folding opportunity is present, so it's worthwhile to
2762 // re-generate the operands list. Group the operands by constant scale,
2763 // to avoid multiplying by the same constant scale multiple times.
2764 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2765 for (const SCEV *NewOp : NewOps)
2766 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2767 // Re-generate the operands list.
2768 Ops.clear();
2769 if (AccumulatedConstant != 0)
2770 Ops.push_back(getConstant(AccumulatedConstant));
2771 for (auto &MulOp : MulOpLists) {
2772 if (MulOp.first == 1) {
2773 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2774 } else if (MulOp.first != 0) {
2775 Ops.push_back(getMulExpr(
2776 getConstant(MulOp.first),
2777 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2778 SCEV::FlagAnyWrap, Depth + 1));
2779 }
2780 }
2781 if (Ops.empty())
2782 return getZero(Ty);
2783 if (Ops.size() == 1)
2784 return Ops[0];
2785 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2786 }
2787 }
2788
2789 // If we are adding something to a multiply expression, make sure the
2790 // something is not already an operand of the multiply. If so, merge it into
2791 // the multiply.
2792 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2793 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2794 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2795 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2796 if (isa<SCEVConstant>(MulOpSCEV))
2797 continue;
2798 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2799 if (MulOpSCEV == Ops[AddOp]) {
2800 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2801 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2802 if (Mul->getNumOperands() != 2) {
2803 // If the multiply has more than two operands, we must get the
2804 // Y*Z term.
2805 SmallVector<const SCEV *, 4> MulOps(
2806 Mul->operands().take_front(MulOp));
2807 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2808 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2809 }
2810 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2811 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2812 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2813 SCEV::FlagAnyWrap, Depth + 1);
2814 if (Ops.size() == 2) return OuterMul;
2815 if (AddOp < Idx) {
2816 Ops.erase(Ops.begin()+AddOp);
2817 Ops.erase(Ops.begin()+Idx-1);
2818 } else {
2819 Ops.erase(Ops.begin()+Idx);
2820 Ops.erase(Ops.begin()+AddOp-1);
2821 }
2822 Ops.push_back(OuterMul);
2823 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2824 }
2825
2826 // Check this multiply against other multiplies being added together.
2827 for (unsigned OtherMulIdx = Idx+1;
2828 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2829 ++OtherMulIdx) {
2830 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2831 // If MulOp occurs in OtherMul, we can fold the two multiplies
2832 // together.
2833 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2834 OMulOp != e; ++OMulOp)
2835 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2836 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2837 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2838 if (Mul->getNumOperands() != 2) {
2839 SmallVector<const SCEV *, 4> MulOps(
2840 Mul->operands().take_front(MulOp));
2841 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2842 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2843 }
2844 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2845 if (OtherMul->getNumOperands() != 2) {
2846 SmallVector<const SCEV *, 4> MulOps(
2847 OtherMul->operands().take_front(OMulOp));
2848 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2849 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2850 }
2851 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2852 const SCEV *InnerMulSum =
2853 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2854 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2855 SCEV::FlagAnyWrap, Depth + 1);
2856 if (Ops.size() == 2) return OuterMul;
2857 Ops.erase(Ops.begin()+Idx);
2858 Ops.erase(Ops.begin()+OtherMulIdx-1);
2859 Ops.push_back(OuterMul);
2860 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2861 }
2862 }
2863 }
2864 }
2865
2866 // If there are any add recurrences in the operands list, see if any other
2867 // added values are loop invariant. If so, we can fold them into the
2868 // recurrence.
2869 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2870 ++Idx;
2871
2872 // Scan over all recurrences, trying to fold loop invariants into them.
2873 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2874 // Scan all of the other operands to this add and add them to the vector if
2875 // they are loop invariant w.r.t. the recurrence.
2876 SmallVector<const SCEV *, 8> LIOps;
2877 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2878 const Loop *AddRecLoop = AddRec->getLoop();
2879 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2880 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2881 LIOps.push_back(Ops[i]);
2882 Ops.erase(Ops.begin()+i);
2883 --i; --e;
2884 }
2885
2886 // If we found some loop invariants, fold them into the recurrence.
2887 if (!LIOps.empty()) {
2888 // Compute nowrap flags for the addition of the loop-invariant ops and
2889 // the addrec. Temporarily push it as an operand for that purpose. These
2890 // flags are valid in the scope of the addrec only.
2891 LIOps.push_back(AddRec);
2892 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2893 LIOps.pop_back();
2894
2895 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2896 LIOps.push_back(AddRec->getStart());
2897
2898 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2899
2900 // It is not in general safe to propagate flags valid on an add within
2901 // the addrec scope to one outside it. We must prove that the inner
2902 // scope is guaranteed to execute if the outer one does to be able to
2903 // safely propagate. We know the program is undefined if poison is
2904 // produced on the inner scoped addrec. We also know that *for this use*
2905 // the outer scoped add can't overflow (because of the flags we just
2906 // computed for the inner scoped add) without the program being undefined.
2907 // Proving that entry to the outer scope necessitates entry to the inner
2908 // scope thus proves the program undefined if the flags would be violated
2909 // in the outer scope.
2910 SCEV::NoWrapFlags AddFlags = Flags;
2911 if (AddFlags != SCEV::FlagAnyWrap) {
2912 auto *DefI = getDefiningScopeBound(LIOps);
2913 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2914 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2915 AddFlags = SCEV::FlagAnyWrap;
2916 }
2917 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2918
2919 // Build the new addrec. Propagate the NUW and NSW flags if both the
2920 // outer add and the inner addrec are guaranteed to have no overflow.
2921 // Always propagate NW.
2922 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2923 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2924
2925 // If all of the other operands were loop invariant, we are done.
2926 if (Ops.size() == 1) return NewRec;
2927
2928 // Otherwise, add the folded AddRec by the non-invariant parts.
2929 for (unsigned i = 0;; ++i)
2930 if (Ops[i] == AddRec) {
2931 Ops[i] = NewRec;
2932 break;
2933 }
2934 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2935 }
2936
2937 // Okay, if there weren't any loop invariants to be folded, check to see if
2938 // there are multiple AddRec's with the same loop induction variable being
2939 // added together. If so, we can fold them.
2940 for (unsigned OtherIdx = Idx+1;
2941 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2942 ++OtherIdx) {
2943 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2944 // so that the 1st found AddRecExpr is dominated by all others.
2945 assert(DT.dominates(
2946 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2947 AddRec->getLoop()->getHeader()) &&
2948 "AddRecExprs are not sorted in reverse dominance order?");
2949 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2950 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2951 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2952 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2953 ++OtherIdx) {
2954 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2955 if (OtherAddRec->getLoop() == AddRecLoop) {
2956 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2957 i != e; ++i) {
2958 if (i >= AddRecOps.size()) {
2959 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2960 break;
2961 }
2962 SmallVector<const SCEV *, 2> TwoOps = {
2963 AddRecOps[i], OtherAddRec->getOperand(i)};
2964 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2965 }
2966 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2967 }
2968 }
2969 // Step size has changed, so we cannot guarantee no self-wraparound.
2970 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2971 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2972 }
2973 }
2974
2975 // Otherwise couldn't fold anything into this recurrence. Move onto the
2976 // next one.
2977 }
2978
2979 // Okay, it looks like we really DO need an add expr. Check to see if we
2980 // already have one, otherwise create a new one.
2981 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2982}
2983
2984const SCEV *
2985ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2986 SCEV::NoWrapFlags Flags) {
2987 FoldingSetNodeID ID;
2988 ID.AddInteger(scAddExpr);
2989 for (const SCEV *Op : Ops)
2990 ID.AddPointer(Op);
2991 void *IP = nullptr;
2992 SCEVAddExpr *S =
2993 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2994 if (!S) {
2995 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2996 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2997 S = new (SCEVAllocator)
2998 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2999 UniqueSCEVs.InsertNode(S, IP);
3000 registerUser(S, Ops);
3001 }
3002 S->setNoWrapFlags(Flags);
3003 return S;
3004}
3005
3006const SCEV *
3007ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3008 const Loop *L, SCEV::NoWrapFlags Flags) {
3009 FoldingSetNodeID ID;
3010 ID.AddInteger(scAddRecExpr);
3011 for (const SCEV *Op : Ops)
3012 ID.AddPointer(Op);
3013 ID.AddPointer(L);
3014 void *IP = nullptr;
3015 SCEVAddRecExpr *S =
3016 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3017 if (!S) {
3018 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3019 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3020 S = new (SCEVAllocator)
3021 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3022 UniqueSCEVs.InsertNode(S, IP);
3023 LoopUsers[L].push_back(S);
3024 registerUser(S, Ops);
3025 }
3026 setNoWrapFlags(S, Flags);
3027 return S;
3028}
3029
3030const SCEV *
3031ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3032 SCEV::NoWrapFlags Flags) {
3033 FoldingSetNodeID ID;
3034 ID.AddInteger(scMulExpr);
3035 for (const SCEV *Op : Ops)
3036 ID.AddPointer(Op);
3037 void *IP = nullptr;
3038 SCEVMulExpr *S =
3039 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3040 if (!S) {
3041 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3042 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3043 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3044 O, Ops.size());
3045 UniqueSCEVs.InsertNode(S, IP);
3046 registerUser(S, Ops);
3047 }
3048 S->setNoWrapFlags(Flags);
3049 return S;
3050}
3051
3052static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3053 uint64_t k = i*j;
3054 if (j > 1 && k / j != i) Overflow = true;
3055 return k;
3056}
3057
3058/// Compute the result of "n choose k", the binomial coefficient. If an
3059/// intermediate computation overflows, Overflow will be set and the return will
3060/// be garbage. Overflow is not cleared on absence of overflow.
3061static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3062 // We use the multiplicative formula:
3063 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3064 // At each iteration, we take the n-th term of the numerator and divide by the
3065 // (k-n)th term of the denominator. This division will always produce an
3066 // integral result, and helps reduce the chance of overflow in the
3067 // intermediate computations. However, we can still overflow even when the
3068 // final result would fit.
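// For illustration, Choose(6, 2): i = 1 gives r = 1*6/1 = 6, i = 2 gives
// r = 6*5/2 = 15, the expected value of C(6, 2).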
3069
3070 if (n == 0 || n == k) return 1;
3071 if (k > n) return 0;
3072
3073 if (k > n/2)
3074 k = n-k;
3075
3076 uint64_t r = 1;
3077 for (uint64_t i = 1; i <= k; ++i) {
3078 r = umul_ov(r, n-(i-1), Overflow);
3079 r /= i;
3080 }
3081 return r;
3082}
3083
3084/// Determine if any of the operands in this SCEV are a constant or if
3085/// any of the add or multiply expressions in this SCEV contain a constant.
3086static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3087 struct FindConstantInAddMulChain {
3088 bool FoundConstant = false;
3089
3090 bool follow(const SCEV *S) {
3091 FoundConstant |= isa<SCEVConstant>(S);
3092 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3093 }
3094
3095 bool isDone() const {
3096 return FoundConstant;
3097 }
3098 };
3099
3100 FindConstantInAddMulChain F;
3101 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3102 ST.visitAll(StartExpr);
3103 return F.FoundConstant;
3104}
3105
3106/// Get a canonical multiply expression, or something simpler if possible.
3107const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3108 SCEV::NoWrapFlags OrigFlags,
3109 unsigned Depth) {
3110 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3111 "only nuw or nsw allowed");
3112 assert(!Ops.empty() && "Cannot get empty mul!");
3113 if (Ops.size() == 1) return Ops[0];
3114#ifndef NDEBUG
3115 Type *ETy = Ops[0]->getType();
3116 assert(!ETy->isPointerTy());
3117 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3118 assert(Ops[i]->getType() == ETy &&
3119 "SCEVMulExpr operand types don't match!");
3120#endif
3121
3122 // Sort by complexity, this groups all similar expression types together.
3123 GroupByComplexity(Ops, &LI, DT);
3124
3125 // If there are any constants, fold them together.
3126 unsigned Idx = 0;
3127 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3128 ++Idx;
3129 assert(Idx < Ops.size());
3130 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3131 // We found two constants, fold them together!
3132 Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3133 if (Ops.size() == 2) return Ops[0];
3134 Ops.erase(Ops.begin()+1); // Erase the folded element
3135 LHSC = cast<SCEVConstant>(Ops[0]);
3136 }
3137
3138 // If we have a multiply of zero, it will always be zero.
3139 if (LHSC->getValue()->isZero())
3140 return LHSC;
3141
3142 // If we are left with a constant one being multiplied, strip it off.
3143 if (LHSC->getValue()->isOne()) {
3144 Ops.erase(Ops.begin());
3145 --Idx;
3146 }
3147
3148 if (Ops.size() == 1)
3149 return Ops[0];
3150 }
3151
3152 // Delay expensive flag strengthening until necessary.
3153 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3154 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3155 };
3156
3157 // Limit recursion calls depth.
3158 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3159 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3160
3161 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3162 // Don't strengthen flags if we have no new information.
3163 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3164 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3165 Mul->setNoWrapFlags(ComputeFlags(Ops));
3166 return S;
3167 }
3168
3169 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3170 if (Ops.size() == 2) {
3171 // C1*(C2+V) -> C1*C2 + C1*V
3172 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3173 // If any of Add's ops are Adds or Muls with a constant, apply this
3174 // transformation as well.
3175 //
3176 // TODO: There are some cases where this transformation is not
3177 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3178 // this transformation should be narrowed down.
3179 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3180 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3181 SCEV::FlagAnyWrap, Depth + 1);
3182 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3183 SCEV::FlagAnyWrap, Depth + 1);
3184 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3185 }
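// For example, 3 * (2 + %x) distributes to (3*2) + (3*%x) = 6 + 3*%x, which
// later folding can then combine with surrounding additive constants.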
3186
3187 if (Ops[0]->isAllOnesValue()) {
3188 // If we have a mul by -1 of an add, try distributing the -1 among the
3189 // add operands.
3190 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3191 SmallVector<const SCEV *, 4> NewOps;
3192 bool AnyFolded = false;
3193 for (const SCEV *AddOp : Add->operands()) {
3194 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3195 Depth + 1);
3196 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3197 NewOps.push_back(Mul);
3198 }
3199 if (AnyFolded)
3200 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3201 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3202 // Negation preserves a recurrence's no self-wrap property.
3203 SmallVector<const SCEV *, 4> Operands;
3204 for (const SCEV *AddRecOp : AddRec->operands())
3205 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3206 Depth + 1));
3207 // Let M be the minimum representable signed value. AddRec with nsw
3208 // multiplied by -1 can have signed overflow if and only if it takes a
3209 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3210 // maximum signed value. In all other cases signed overflow is
3211 // impossible.
3212 auto FlagsMask = SCEV::FlagNW;
3213 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3214 auto MinInt =
3215 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3216 if (getSignedRangeMin(AddRec) != MinInt)
3217 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3218 }
3219 return getAddRecExpr(Operands, AddRec->getLoop(),
3220 AddRec->getNoWrapFlags(FlagsMask));
3221 }
3222 }
3223 }
3224 }
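// For example, (-1) * {1,+,2}<nsw> becomes {-1,+,-2}: the NW property is
// always preserved, and NSW is kept as long as the recurrence is known never
// to reach the minimum signed value.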
3225
3226 // Skip over the add expression until we get to a multiply.
3227 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3228 ++Idx;
3229
3230 // If there are mul operands inline them all into this expression.
3231 if (Idx < Ops.size()) {
3232 bool DeletedMul = false;
3233 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3234 if (Ops.size() > MulOpsInlineThreshold)
3235 break;
3236 // If we have a mul, expand the mul operands onto the end of the
3237 // operands list.
3238 Ops.erase(Ops.begin()+Idx);
3239 append_range(Ops, Mul->operands());
3240 DeletedMul = true;
3241 }
3242
3243 // If we deleted at least one mul, we added operands to the end of the
3244 // list, and they are not necessarily sorted. Recurse to resort and
3245 // resimplify any operands we just acquired.
3246 if (DeletedMul)
3247 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3248 }
3249
3250 // If there are any add recurrences in the operands list, see if any other
3251 // added values are loop invariant. If so, we can fold them into the
3252 // recurrence.
3253 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3254 ++Idx;
3255
3256 // Scan over all recurrences, trying to fold loop invariants into them.
3257 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3258 // Scan all of the other operands to this mul and add them to the vector
3259 // if they are loop invariant w.r.t. the recurrence.
3260 SmallVector<const SCEV *, 8> LIOps;
3261 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3262 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3263 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3264 LIOps.push_back(Ops[i]);
3265 Ops.erase(Ops.begin()+i);
3266 --i; --e;
3267 }
3268
3269 // If we found some loop invariants, fold them into the recurrence.
3270 if (!LIOps.empty()) {
3271 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3272 SmallVector<const SCEV *, 4> NewOps;
3273 NewOps.reserve(AddRec->getNumOperands());
3274 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3275
3276 // If both the mul and addrec are nuw, we can preserve nuw.
3277 // If both the mul and addrec are nsw, we can only preserve nsw if either
3278 // a) they are also nuw, or
3279 // b) all multiplications of addrec operands with scale are nsw.
3280 SCEV::NoWrapFlags Flags =
3281 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3282
3283 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3284 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3285 SCEV::FlagAnyWrap, Depth + 1));
3286
3287 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3288 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3289 Instruction::Mul, getSignedRange(Scale),
3290 OverflowingBinaryOperator::NoSignedWrap);
3291 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3292 Flags = clearFlags(Flags, SCEV::FlagNSW);
3293 }
3294 }
3295
3296 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3297
3298 // If all of the other operands were loop invariant, we are done.
3299 if (Ops.size() == 1) return NewRec;
3300
3301 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3302 for (unsigned i = 0;; ++i)
3303 if (Ops[i] == AddRec) {
3304 Ops[i] = NewRec;
3305 break;
3306 }
3307 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3308 }
3309
3310 // Okay, if there weren't any loop invariants to be folded, check to see
3311 // if there are multiple AddRec's with the same loop induction variable
3312 // being multiplied together. If so, we can fold them.
3313
3314 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3315 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3316 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3317 // ]]],+,...up to x=2n}.
3318 // Note that the arguments to choose() are always integers with values
3319 // known at compile time, never SCEV objects.
3320 //
3321 // The implementation avoids pointless extra computations when the two
3322 // addrec's are of different length (mathematically, it's equivalent to
3323 // an infinite stream of zeros on the right).
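// For example, {1,+,1}<L> * {1,+,1}<L>, i.e. (i+1)*(i+1), folds to the single
// recurrence {1,+,3,+,2}<L>, whose values 1, 4, 9, 16, ... are the perfect
// squares.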
3324 bool OpsModified = false;
3325 for (unsigned OtherIdx = Idx+1;
3326 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3327 ++OtherIdx) {
3328 const SCEVAddRecExpr *OtherAddRec =
3329 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3330 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3331 continue;
3332
3333 // Limit max number of arguments to avoid creation of unreasonably big
3334 // SCEVAddRecs with very complex operands.
3335 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3336 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3337 continue;
3338
3339 bool Overflow = false;
3340 Type *Ty = AddRec->getType();
3341 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3342 SmallVector<const SCEV *, 7> AddRecOps;
3343 for (int x = 0, xe = AddRec->getNumOperands() +
3344 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3345 SmallVector <const SCEV *, 7> SumOps;
3346 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3347 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3348 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3349 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3350 z < ze && !Overflow; ++z) {
3351 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3352 uint64_t Coeff;
3353 if (LargerThan64Bits)
3354 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3355 else
3356 Coeff = Coeff1*Coeff2;
3357 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3358 const SCEV *Term1 = AddRec->getOperand(y-z);
3359 const SCEV *Term2 = OtherAddRec->getOperand(z);
3360 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3361 SCEV::FlagAnyWrap, Depth + 1));
3362 }
3363 }
3364 if (SumOps.empty())
3365 SumOps.push_back(getZero(Ty));
3366 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3367 }
3368 if (!Overflow) {
3369 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3370 SCEV::FlagAnyWrap);
3371 if (Ops.size() == 2) return NewAddRec;
3372 Ops[Idx] = NewAddRec;
3373 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3374 OpsModified = true;
3375 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3376 if (!AddRec)
3377 break;
3378 }
3379 }
3380 if (OpsModified)
3381 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3382
3383 // Otherwise couldn't fold anything into this recurrence. Move onto the
3384 // next one.
3385 }
3386
3387 // Okay, it looks like we really DO need a mul expr. Check to see if we
3388 // already have one, otherwise create a new one.
3389 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3390}
3391
3392/// Represents an unsigned remainder expression based on unsigned division.
3393const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3394 const SCEV *RHS) {
3395 assert(LHS->getType()->isIntegerTy() && RHS->getType()->isIntegerTy() &&
3396 LHS->getType() == RHS->getType() &&
3397 "SCEVURemExpr operand types don't match!");
3398
3399 // Short-circuit easy cases
3400 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3401 // If constant is one, the result is trivial
3402 if (RHSC->getValue()->isOne())
3403 return getZero(LHS->getType()); // X urem 1 --> 0
3404
3405 // If constant is a power of two, fold into a zext(trunc(LHS)).
3406 if (RHSC->getAPInt().isPowerOf2()) {
3407 Type *FullTy = LHS->getType();
3408 Type *TruncTy =
3409 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3410 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3411 }
3412 }
3413
3414 // Fall back to the identity %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3415 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3416 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3417 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3418}
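// For example, for an i32 %x, (%x urem 8) becomes (zext i3 (trunc %x) to i32),
// while (%x urem %y) in general becomes %x - ((%x udiv %y) * %y).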
3419
3420/// Get a canonical unsigned division expression, or something simpler if
3421/// possible.
3422const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3423 const SCEV *RHS) {
3424 assert(!LHS->getType()->isPointerTy() &&
3425 "SCEVUDivExpr operand can't be pointer!");
3426 assert(LHS->getType() == RHS->getType() &&
3427 "SCEVUDivExpr operand types don't match!");
3428
3429 FoldingSetNodeID ID;
3430 ID.AddInteger(scUDivExpr);
3431 ID.AddPointer(LHS);
3432 ID.AddPointer(RHS);
3433 void *IP = nullptr;
3434 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3435 return S;
3436
3437 // 0 udiv Y == 0
3438 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3439 if (LHSC->getValue()->isZero())
3440 return LHS;
3441
3442 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3443 if (RHSC->getValue()->isOne())
3444 return LHS; // X udiv 1 --> x
3445 // If the denominator is zero, the result of the udiv is undefined. Don't
3446 // try to analyze it, because the resolution chosen here may differ from
3447 // the resolution chosen in other parts of the compiler.
3448 if (!RHSC->getValue()->isZero()) {
3449 // Determine if the division can be folded into the operands of
3450 // its operands.
3451 // TODO: Generalize this to non-constants by using known-bits information.
3452 Type *Ty = LHS->getType();
3453 unsigned LZ = RHSC->getAPInt().countl_zero();
3454 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3455 // For non-power-of-two values, effectively round the value up to the
3456 // nearest power of two.
3457 if (!RHSC->getAPInt().isPowerOf2())
3458 ++MaxShiftAmt;
3459 IntegerType *ExtTy =
3460 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3461 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3462 if (const SCEVConstant *Step =
3463 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3464 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3465 const APInt &StepInt = Step->getAPInt();
3466 const APInt &DivInt = RHSC->getAPInt();
3467 if (!StepInt.urem(DivInt) &&
3468 getZeroExtendExpr(AR, ExtTy) ==
3469 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3470 getZeroExtendExpr(Step, ExtTy),
3471 AR->getLoop(), SCEV::FlagAnyWrap)) {
3472 SmallVector<const SCEV *, 4> Operands;
3473 for (const SCEV *Op : AR->operands())
3474 Operands.push_back(getUDivExpr(Op, RHS));
3475 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3476 }
3477 /// Get a canonical UDivExpr for a recurrence.
3478 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3479 // We can currently only fold X%N if X is constant.
3480 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3481 if (StartC && !DivInt.urem(StepInt) &&
3482 getZeroExtendExpr(AR, ExtTy) ==
3483 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3484 getZeroExtendExpr(Step, ExtTy),
3485 AR->getLoop(), SCEV::FlagAnyWrap)) {
3486 const APInt &StartInt = StartC->getAPInt();
3487 const APInt &StartRem = StartInt.urem(StepInt);
3488 if (StartRem != 0) {
3489 const SCEV *NewLHS =
3490 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3491 AR->getLoop(), SCEV::FlagNW);
3492 if (LHS != NewLHS) {
3493 LHS = NewLHS;
3494
3495 // Reset the ID to include the new LHS, and check if it is
3496 // already cached.
3497 ID.clear();
3498 ID.AddInteger(scUDivExpr);
3499 ID.AddPointer(LHS);
3500 ID.AddPointer(RHS);
3501 IP = nullptr;
3502 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3503 return S;
3504 }
3505 }
3506 }
3507 }
3508 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3509 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3510 SmallVector<const SCEV *, 4> Operands;
3511 for (const SCEV *Op : M->operands())
3512 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3513 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3514 // Find an operand that's safely divisible.
3515 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3516 const SCEV *Op = M->getOperand(i);
3517 const SCEV *Div = getUDivExpr(Op, RHSC);
3518 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3519 Operands = SmallVector<const SCEV *, 4>(M->operands());
3520 Operands[i] = Div;
3521 return getMulExpr(Operands);
3522 }
3523 }
3524 }
3525
3526 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3527 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3528 if (auto *DivisorConstant =
3529 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3530 bool Overflow = false;
3531 APInt NewRHS =
3532 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3533 if (Overflow) {
3534 return getConstant(RHSC->getType(), 0, false);
3535 }
3536 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3537 }
3538 }
3539
3540 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3541 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3542 SmallVector<const SCEV *, 4> Operands;
3543 for (const SCEV *Op : A->operands())
3544 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3545 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3546 Operands.clear();
3547 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3548 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3549 if (isa<SCEVUDivExpr>(Op) ||
3550 getMulExpr(Op, RHS) != A->getOperand(i))
3551 break;
3552 Operands.push_back(Op);
3553 }
3554 if (Operands.size() == A->getNumOperands())
3555 return getAddExpr(Operands);
3556 }
3557 }
3558
3559 // Fold if both operands are constant.
3560 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3561 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3562 }
3563 }
3564
3565 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3566 // changes). Make sure we get a new one.
3567 IP = nullptr;
3568 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3569 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3570 LHS, RHS);
3571 UniqueSCEVs.InsertNode(S, IP);
3572 registerUser(S, {LHS, RHS});
3573 return S;
3574}
3575
3576APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3577 APInt A = C1->getAPInt().abs();
3578 APInt B = C2->getAPInt().abs();
3579 uint32_t ABW = A.getBitWidth();
3580 uint32_t BBW = B.getBitWidth();
3581
3582 if (ABW > BBW)
3583 B = B.zext(ABW);
3584 else if (ABW < BBW)
3585 A = A.zext(BBW);
3586
3587 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3588}
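// For example, for the constants 12 and 18 (zero-extended to a common width),
// this returns 6; the abs() above also makes gcd(-12, 18) return 6.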
3589
3590/// Get a canonical unsigned division expression, or something simpler if
3591/// possible. There is no representation for an exact udiv in SCEV IR, but we
3592/// can attempt to remove factors from the LHS and RHS. We can't do this when
3593/// it's not exact because the udiv may be clearing bits.
3595 const SCEV *RHS) {
3596 // TODO: we could try to find factors in all sorts of things, but for now we
3597 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3598 // end of this file for inspiration.
3599
3600 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3601 if (!Mul || !Mul->hasNoUnsignedWrap())
3602 return getUDivExpr(LHS, RHS);
3603
3604 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3605 // If the mulexpr multiplies by a constant, then that constant must be the
3606 // first element of the mulexpr.
3607 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3608 if (LHSCst == RHSCst) {
3609 SmallVector<const SCEV *, 2> Operands(Mul->operands().drop_front());
3610 return getMulExpr(Operands);
3611 }
3612
3613 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3614 // that there's a factor provided by one of the other terms. We need to
3615 // check.
3616 APInt Factor = gcd(LHSCst, RHSCst);
3617 if (!Factor.isIntN(1)) {
3618 LHSCst =
3619 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3620 RHSCst =
3621 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3622 SmallVector<const SCEV *, 2> Operands;
3623 Operands.push_back(LHSCst);
3624 append_range(Operands, Mul->operands().drop_front());
3625 LHS = getMulExpr(Operands);
3626 RHS = RHSCst;
3627 Mul = dyn_cast<SCEVMulExpr>(LHS);
3628 if (!Mul)
3629 return getUDivExactExpr(LHS, RHS);
3630 }
3631 }
3632 }
3633
3634 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3635 if (Mul->getOperand(i) == RHS) {
3636 SmallVector<const SCEV *, 2> Operands;
3637 append_range(Operands, Mul->operands().take_front(i));
3638 append_range(Operands, Mul->operands().drop_front(i + 1));
3639 return getMulExpr(Operands);
3640 }
3641 }
3642
3643 return getUDivExpr(LHS, RHS);
3644}
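// For example, an exact division of (4 * %x)<nuw> by 4 simply drops the
// constant factor and yields %x, and ((6 * %x)<nuw>) /u 2 is reduced via the
// gcd to 3 * %x.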
3645
3646/// Get an add recurrence expression for the specified loop. Simplify the
3647/// expression as much as possible.
3648const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3649 const Loop *L,
3650 SCEV::NoWrapFlags Flags) {
3651 SmallVector<const SCEV *, 4> Operands;
3652 Operands.push_back(Start);
3653 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3654 if (StepChrec->getLoop() == L) {
3655 append_range(Operands, StepChrec->operands());
3656 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3657 }
3658
3659 Operands.push_back(Step);
3660 return getAddRecExpr(Operands, L, Flags);
3661}
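// For example, getAddRecExpr(%start, {1,+,2}<L>, L) flattens the nested step
// into the single recurrence {%start,+,1,+,2}<L>.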
3662
3663/// Get an add recurrence expression for the specified loop. Simplify the
3664/// expression as much as possible.
3665const SCEV *
3666ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3667 const Loop *L, SCEV::NoWrapFlags Flags) {
3668 if (Operands.size() == 1) return Operands[0];
3669#ifndef NDEBUG
3670 Type *ETy = Operands[0]->getType();
3671 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3672 assert(Operands[i]->getType() == ETy &&
3673 "SCEVAddRecExpr operand types don't match!");
3674 assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3675 }
3676 for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3678 "SCEVAddRecExpr operand is not available at loop entry!");
3679#endif
3680
3681 if (Operands.back()->isZero()) {
3682 Operands.pop_back();
3683 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3684 }
3685
3686 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3687 // use that information to infer NUW and NSW flags. However, computing a
3688 // BE count requires calling getAddRecExpr, so we may not yet have a
3689 // meaningful BE count at this point (and if we don't, we'd be stuck
3690 // with a SCEVCouldNotCompute as the cached BE count).
3691
3692 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3693
3694 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3695 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3696 const Loop *NestedLoop = NestedAR->getLoop();
3697 if (L->contains(NestedLoop)
3698 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3699 : (!NestedLoop->contains(L) &&
3700 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3701 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3702 Operands[0] = NestedAR->getStart();
3703 // AddRecs require their operands be loop-invariant with respect to their
3704 // loops. Don't perform this transformation if it would break this
3705 // requirement.
3706 bool AllInvariant = all_of(
3707 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3708
3709 if (AllInvariant) {
3710 // Create a recurrence for the outer loop with the same step size.
3711 //
3712 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3713 // inner recurrence has the same property.
3714 SCEV::NoWrapFlags OuterFlags =
3715 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3716
3717 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3718 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3719 return isLoopInvariant(Op, NestedLoop);
3720 });
3721
3722 if (AllInvariant) {
3723 // Ok, both add recurrences are valid after the transformation.
3724 //
3725 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3726 // the outer recurrence has the same property.
3727 SCEV::NoWrapFlags InnerFlags =
3728 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3729 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3730 }
3731 }
3732 // Reset Operands to its original state.
3733 Operands[0] = NestedAR;
3734 }
3735 }
3736
3737 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3738 // already have one, otherwise create a new one.
3739 return getOrCreateAddRecExpr(Operands, L, Flags);
3740}
3741
3742const SCEV *
3743ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3744 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3745 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3746 // getSCEV(Base)->getType() has the same address space as Base->getType()
3747 // because SCEV::getType() preserves the address space.
3748 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3749 const bool AssumeInBoundsFlags = [&]() {
3750 if (!GEP->isInBounds())
3751 return false;
3752
3753 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3754 // but to do that, we have to ensure that said flag is valid in the entire
3755 // defined scope of the SCEV.
3756 auto *GEPI = dyn_cast<Instruction>(GEP);
3757 // TODO: non-instructions have global scope. We might be able to prove
3758 // some global scope cases
3759 return GEPI && isSCEVExprNeverPoison(GEPI);
3760 }();
3761
3762 SCEV::NoWrapFlags OffsetWrap =
3763 AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3764
3765 Type *CurTy = GEP->getType();
3766 bool FirstIter = true;
3767 SmallVector<const SCEV *, 4> Offsets;
3768 for (const SCEV *IndexExpr : IndexExprs) {
3769 // Compute the (potentially symbolic) offset in bytes for this index.
3770 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3771 // For a struct, add the member offset.
3772 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3773 unsigned FieldNo = Index->getZExtValue();
3774 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3775 Offsets.push_back(FieldOffset);
3776
3777 // Update CurTy to the type of the field at Index.
3778 CurTy = STy->getTypeAtIndex(Index);
3779 } else {
3780 // Update CurTy to its element type.
3781 if (FirstIter) {
3782 assert(isa<PointerType>(CurTy) &&
3783 "The first index of a GEP indexes a pointer");
3784 CurTy = GEP->getSourceElementType();
3785 FirstIter = false;
3786 } else {
3787 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3788 }
3789 // For an array, add the element offset, explicitly scaled.
3790 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3791 // Getelementptr indices are signed.
3792 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3793
3794 // Multiply the index by the element size to compute the element offset.
3795 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3796 Offsets.push_back(LocalOffset);
3797 }
3798 }
3799
3800 // Handle degenerate case of GEP without offsets.
3801 if (Offsets.empty())
3802 return BaseExpr;
3803
3804 // Add the offsets together, assuming nsw if inbounds.
3805 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3806 // Add the base address and the offset. We cannot use the nsw flag, as the
3807 // base address is unsigned. However, if we know that the offset is
3808 // non-negative, we can use nuw.
3809 SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3810 ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3811 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3812 assert(BaseExpr->getType() == GEPExpr->getType() &&
3813 "GEP should not change type mid-flight.");
3814 return GEPExpr;
3815}
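// For example, for "getelementptr inbounds i32, ptr %p, i64 %i" on a typical
// 64-bit target, the offset becomes (4 * %i)<nsw> and the result is
// (%p + 4*%i), with nuw added to the final add only when the offset is known
// to be non-negative.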
3816
3817SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3818 ArrayRef<const SCEV *> Ops) {
3819 FoldingSetNodeID ID;
3820 ID.AddInteger(SCEVType);
3821 for (const SCEV *Op : Ops)
3822 ID.AddPointer(Op);
3823 void *IP = nullptr;
3824 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3825}
3826
3827const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3828 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3829 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3830}
3831
3832const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3833 SmallVectorImpl<const SCEV *> &Ops) {
3834 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3835 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3836 if (Ops.size() == 1) return Ops[0];
3837#ifndef NDEBUG
3838 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3839 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3840 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3841 "Operand types don't match!");
3842 assert(Ops[0]->getType()->isPointerTy() ==
3843 Ops[i]->getType()->isPointerTy() &&
3844 "min/max should be consistently pointerish");
3845 }
3846#endif
3847
3848 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3849 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3850
3851 // Sort by complexity, this groups all similar expression types together.
3852 GroupByComplexity(Ops, &LI, DT);
3853
3854 // Check if we have created the same expression before.
3855 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3856 return S;
3857 }
3858
3859 // If there are any constants, fold them together.
3860 unsigned Idx = 0;
3861 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3862 ++Idx;
3863 assert(Idx < Ops.size());
3864 auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3865 switch (Kind) {
3866 case scSMaxExpr:
3867 return APIntOps::smax(LHS, RHS);
3868 case scSMinExpr:
3869 return APIntOps::smin(LHS, RHS);
3870 case scUMaxExpr:
3871 return APIntOps::umax(LHS, RHS);
3872 case scUMinExpr:
3873 return APIntOps::umin(LHS, RHS);
3874 default:
3875 llvm_unreachable("Unknown SCEV min/max opcode");
3876 }
3877 };
3878
3879 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3880 // We found two constants, fold them together!
3881 ConstantInt *Fold = ConstantInt::get(
3882 getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3883 Ops[0] = getConstant(Fold);
3884 Ops.erase(Ops.begin()+1); // Erase the folded element
3885 if (Ops.size() == 1) return Ops[0];
3886 LHSC = cast<SCEVConstant>(Ops[0]);
3887 }
3888
3889 bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3890 bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3891
3892 if (IsMax ? IsMinV : IsMaxV) {
3893 // If we are left with a constant minimum(/maximum)-int, strip it off.
3894 Ops.erase(Ops.begin());
3895 --Idx;
3896 } else if (IsMax ? IsMaxV : IsMinV) {
3897 // If we have a max(/min) with a constant maximum(/minimum)-int,
3898 // it will always be the extremum.
3899 return LHSC;
3900 }
3901
3902 if (Ops.size() == 1) return Ops[0];
3903 }
3904
3905 // Find the first operation of the same kind
3906 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3907 ++Idx;
3908
3909 // Check to see if one of the operands is of the same kind. If so, expand its
3910 // operands onto our operand list, and recurse to simplify.
3911 if (Idx < Ops.size()) {
3912 bool DeletedAny = false;
3913 while (Ops[Idx]->getSCEVType() == Kind) {
3914 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3915 Ops.erase(Ops.begin()+Idx);
3916 append_range(Ops, SMME->operands());
3917 DeletedAny = true;
3918 }
3919
3920 if (DeletedAny)
3921 return getMinMaxExpr(Kind, Ops);
3922 }
3923
3924 // Okay, check to see if the same value occurs in the operand list twice. If
3925 // so, delete one. Since we sorted the list, these values are required to
3926 // be adjacent.
3927 llvm::CmpInst::Predicate GEPred =
3928 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3929 llvm::CmpInst::Predicate LEPred =
3930 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3931 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3932 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3933 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3934 if (Ops[i] == Ops[i + 1] ||
3935 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3936 // X op Y op Y --> X op Y
3937 // X op Y --> X, if we know X, Y are ordered appropriately
3938 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3939 --i;
3940 --e;
3941 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3942 Ops[i + 1])) {
3943 // X op Y --> Y, if we know X, Y are ordered appropriately
3944 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3945 --i;
3946 --e;
3947 }
3948 }
3949
3950 if (Ops.size() == 1) return Ops[0];
3951
3952 assert(!Ops.empty() && "Reduced smax down to nothing!");
3953
3954 // Okay, it looks like we really DO need an expr. Check to see if we
3955 // already have one, otherwise create a new one.
3956 FoldingSetNodeID ID;
3957 ID.AddInteger(Kind);
3958 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3959 ID.AddPointer(Ops[i]);
3960 void *IP = nullptr;
3961 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3962 if (ExistingSCEV)
3963 return ExistingSCEV;
3964 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3965 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3966 SCEV *S = new (SCEVAllocator)
3967 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3968
3969 UniqueSCEVs.InsertNode(S, IP);
3970 registerUser(S, Ops);
3971 return S;
3972}
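// For example, umax(%x, %x, %y) collapses the duplicate to umax(%x, %y), and
// umax(%x, %y) simplifies to %x outright when %x uge %y can be proven.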
3973
3974namespace {
3975
3976class SCEVSequentialMinMaxDeduplicatingVisitor final
3977 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3978 std::optional<const SCEV *>> {
3979 using RetVal = std::optional<const SCEV *>;
3980 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3981
3982 ScalarEvolution &SE;
3983 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3984 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3985 SmallPtrSet<const SCEV *, 16> SeenOps;
3986
3987 bool canRecurseInto(SCEVTypes Kind) const {
3988 // We can only recurse into the SCEV expression of the same effective type
3989 // as the type of our root SCEV expression.
3990 return RootKind == Kind || NonSequentialRootKind == Kind;
3991 };
3992
3993 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3994 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3995 "Only for min/max expressions.");
3996 SCEVTypes Kind = S->getSCEVType();
3997
3998 if (!canRecurseInto(Kind))
3999 return S;
4000
4001 auto *NAry = cast<SCEVNAryExpr>(S);
4002 SmallVector<const SCEV *> NewOps;
4003 bool Changed = visit(Kind, NAry->operands(), NewOps);
4004
4005 if (!Changed)
4006 return S;
4007 if (NewOps.empty())
4008 return std::nullopt;
4009
4010 return isa<SCEVSequentialMinMaxExpr>(S)
4011 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4012 : SE.getMinMaxExpr(Kind, NewOps);
4013 }
4014
4015 RetVal visit(const SCEV *S) {
4016 // Has the whole operand been seen already?
4017 if (!SeenOps.insert(S).second)
4018 return std::nullopt;
4019 return Base::visit(S);
4020 }
4021
4022public:
4023 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4024 SCEVTypes RootKind)
4025 : SE(SE), RootKind(RootKind),
4026 NonSequentialRootKind(
4027 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4028 RootKind)) {}
4029
4030 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4031 SmallVectorImpl<const SCEV *> &NewOps) {
4032 bool Changed = false;
4033 SmallVector<const SCEV *> Ops;
4034 Ops.reserve(OrigOps.size());
4035
4036 for (const SCEV *Op : OrigOps) {
4037 RetVal NewOp = visit(Op);
4038 if (NewOp != Op)
4039 Changed = true;
4040 if (NewOp)
4041 Ops.emplace_back(*NewOp);
4042 }
4043
4044 if (Changed)
4045 NewOps = std::move(Ops);
4046 return Changed;
4047 }
4048
4049 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4050
4051 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4052
4053 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4054
4055 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4056
4057 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4058
4059 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4060
4061 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4062
4063 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4064
4065 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4066
4067 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4068
4069 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4070 return visitAnyMinMaxExpr(Expr);
4071 }
4072
4073 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4074 return visitAnyMinMaxExpr(Expr);
4075 }
4076
4077 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4078 return visitAnyMinMaxExpr(Expr);
4079 }
4080
4081 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4082 return visitAnyMinMaxExpr(Expr);
4083 }
4084
4085 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4086 return visitAnyMinMaxExpr(Expr);
4087 }
4088
4089 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4090
4091 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4092};
4093
4094} // namespace
4095
4096static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4097 switch (Kind) {
4098 case scConstant:
4099 case scVScale:
4100 case scTruncate:
4101 case scZeroExtend:
4102 case scSignExtend:
4103 case scPtrToInt:
4104 case scAddExpr:
4105 case scMulExpr:
4106 case scUDivExpr:
4107 case scAddRecExpr:
4108 case scUMaxExpr:
4109 case scSMaxExpr:
4110 case scUMinExpr:
4111 case scSMinExpr:
4112 case scUnknown:
4113 // If any operand is poison, the whole expression is poison.
4114 return true;
4115 case scSequentialUMinExpr:
4116 // FIXME: if the *first* operand is poison, the whole expression is poison.
4117 return false; // Pessimistically, say that it does not propagate poison.
4118 case scCouldNotCompute:
4119 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4120 }
4121 llvm_unreachable("Unknown SCEV kind!");
4122}
4123
4124namespace {
4125// The only way poison may be introduced in a SCEV expression is from a
4126// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4127// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4128// introduce poison -- they encode guaranteed, non-speculated knowledge.
4129//
4130// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4131// with the notable exception of umin_seq, where only poison from the first
4132// operand is (unconditionally) propagated.
4133struct SCEVPoisonCollector {
4134 bool LookThroughMaybePoisonBlocking;
4135 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4136 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4137 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4138
4139 bool follow(const SCEV *S) {
4140 if (!LookThroughMaybePoisonBlocking &&
4141 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4142 return false;
4143
4144 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4145 if (!isGuaranteedNotToBePoison(SU->getValue()))
4146 MaybePoison.insert(SU);
4147 }
4148 return true;
4149 }
4150 bool isDone() const { return false; }
4151};
4152} // namespace
4153
4154/// Return true if V is poison given that AssumedPoison is already poison.
4155static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4156 // First collect all SCEVs that might result in AssumedPoison to be poison.
4157 // We need to look through potentially poison-blocking operations here,
4158 // because we want to find all SCEVs that *might* result in poison, not only
4159 // those that are *required* to.
4160 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4161 visitAll(AssumedPoison, PC1);
4162
4163 // AssumedPoison is never poison. As the assumption is false, the implication
4164 // is true. Don't bother walking the other SCEV in this case.
4165 if (PC1.MaybePoison.empty())
4166 return true;
4167
4168 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4169 // as well. We cannot look through potentially poison-blocking operations
4170 // here, as their arguments only *may* make the result poison.
4171 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4172 visitAll(S, PC2);
4173
4174 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4175 // it will also make S poison by being part of PC2.MaybePoison.
4176 return all_of(PC1.MaybePoison, [&](const SCEVUnknown *S) {
4177 return PC2.MaybePoison.contains(S);
4178 });
4179}
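// For example, if AssumedPoison is (zext %x), its only poison contributor is
// %x; an S of (%x + %y) is poison whenever %x is, so the implication holds.
// If AssumedPoison were %y instead, the implication would fail for S = %x.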
4180
4181void ScalarEvolution::getPoisonGeneratingValues(
4182 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4183 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4184 visitAll(S, PC);
4185 for (const SCEVUnknown *SU : PC.MaybePoison)
4186 Result.insert(SU->getValue());
4187}
4188
4189bool ScalarEvolution::canReuseInstruction(
4190 const SCEV *S, Instruction *I,
4191 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4192 // If the instruction cannot be poison, it's always safe to reuse.
4193 if (isGuaranteedNotToBePoison(I))
4194 return true;
4195
4196 // Otherwise, it is possible that I is more poisonous than S. Collect the
4197 // poison-contributors of S, and then check whether I has any additional
4198 // poison-contributors. Poison that is contributed through poison-generating
4199 // flags is handled by dropping those flags instead.
4200 SmallPtrSet<const Value *, 8> PoisonVals;
4201 getPoisonGeneratingValues(PoisonVals, S);
4202
4203 SmallVector<Value *> Worklist;
4204 SmallPtrSet<Value *, 8> Visited;
4205 Worklist.push_back(I);
4206 while (!Worklist.empty()) {
4207 Value *V = Worklist.pop_back_val();
4208 if (!Visited.insert(V).second)
4209 continue;
4210
4211 // Avoid walking large instruction graphs.
4212 if (Visited.size() > 16)
4213 return false;
4214
4215 // Either the value can't be poison, or the S would also be poison if it
4216 // is.
4217 if (PoisonVals.contains(V) || isGuaranteedNotToBePoison(V))
4218 continue;
4219
4220 auto *I = dyn_cast<Instruction>(V);
4221 if (!I)
4222 return false;
4223
4224 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4225 // can't replace an arbitrary add with disjoint or, even if we drop the
4226 // flag. We would need to convert the or into an add.
4227 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4228 if (PDI->isDisjoint())
4229 return false;
4230
4231 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4232 // because SCEV currently assumes it can't be poison. Remove this special
4233 // case once we properly model when vscale can be poison.
4234 if (auto *II = dyn_cast<IntrinsicInst>(I);
4235 II && II->getIntrinsicID() == Intrinsic::vscale)
4236 continue;
4237
4238 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4239 return false;
4240
4241 // If the instruction can't create poison, we can recurse to its operands.
4242 if (I->hasPoisonGeneratingAnnotations())
4243 DropPoisonGeneratingInsts.push_back(I);
4244
4245 for (Value *Op : I->operands())
4246 Worklist.push_back(Op);
4247 }
4248 return true;
4249}
4250
4251const SCEV *
4252ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4253 SmallVectorImpl<const SCEV *> &Ops) {
4254 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4255 "Not a SCEVSequentialMinMaxExpr!");
4256 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4257 if (Ops.size() == 1)
4258 return Ops[0];
4259#ifndef NDEBUG
4260 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4261 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4262 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4263 "Operand types don't match!");
4264 assert(Ops[0]->getType()->isPointerTy() ==
4265 Ops[i]->getType()->isPointerTy() &&
4266 "min/max should be consistently pointerish");
4267 }
4268#endif
4269
4270 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4271 // so we can *NOT* do any kind of sorting of the expressions!
4272
4273 // Check if we have created the same expression before.
4274 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4275 return S;
4276
4277 // FIXME: there are *some* simplifications that we can do here.
4278
4279 // Keep only the first instance of an operand.
4280 {
4281 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4282 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4283 if (Changed)
4284 return getSequentialMinMaxExpr(Kind, Ops);
4285 }
4286
4287 // Check to see if one of the operands is of the same kind. If so, expand its
4288 // operands onto our operand list, and recurse to simplify.
4289 {
4290 unsigned Idx = 0;
4291 bool DeletedAny = false;
4292 while (Idx < Ops.size()) {
4293 if (Ops[Idx]->getSCEVType() != Kind) {
4294 ++Idx;
4295 continue;
4296 }
4297 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4298 Ops.erase(Ops.begin() + Idx);
4299 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4300 SMME->operands().end());
4301 DeletedAny = true;
4302 }
4303
4304 if (DeletedAny)
4305 return getSequentialMinMaxExpr(Kind, Ops);
4306 }
4307
4308 const SCEV *SaturationPoint;
4309 ICmpInst::Predicate Pred;
4310 switch (Kind) {
4311 case scSequentialUMinExpr:
4312 SaturationPoint = getZero(Ops[0]->getType());
4313 Pred = ICmpInst::ICMP_ULE;
4314 break;
4315 default:
4316 llvm_unreachable("Not a sequential min/max type.");
4317 }
4318
4319 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4320 // We can replace %x umin_seq %y with %x umin %y if either:
4321 // * %y being poison implies %x is also poison.
4322 // * %x cannot be the saturating value (e.g. zero for umin).
4323 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4324 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4325 SaturationPoint)) {
4326 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4327 Ops[i - 1] = getMinMaxExpr(
4328 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4329 SeqOps);
4330 Ops.erase(Ops.begin() + i);
4331 return getSequentialMinMaxExpr(Kind, Ops);
4332 }
4333 // Fold %x umin_seq %y to %x if %x ule %y.
4334 // TODO: We might be able to prove the predicate for a later operand.
4335 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4336 Ops.erase(Ops.begin() + i);
4337 return getSequentialMinMaxExpr(Kind, Ops);
4338 }
4339 }
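// For example, %x umin_seq %x deduplicates to %x above, and %x umin_seq %y
// becomes %x umin %y whenever %x is known to be non-zero, since %x can then
// never be the saturating value that would mask poison in %y.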
4340
4341 // Okay, it looks like we really DO need an expr. Check to see if we
4342 // already have one, otherwise create a new one.
4343 FoldingSetNodeID ID;
4344 ID.AddInteger(Kind);
4345 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
4346 ID.AddPointer(Ops[i]);
4347 void *IP = nullptr;
4348 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4349 if (ExistingSCEV)
4350 return ExistingSCEV;
4351
4352 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4353 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4354 SCEV *S = new (SCEVAllocator)
4355 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4356
4357 UniqueSCEVs.InsertNode(S, IP);
4358 registerUser(S, Ops);
4359 return S;
4360}
4361
4362const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4363 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4364 return getSMaxExpr(Ops);
4365}
4366
4367const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4368 return getMinMaxExpr(scSMaxExpr, Ops);
4369}
4370
4371const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4372 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4373 return getUMaxExpr(Ops);
4374}
4375
4376const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4377 return getMinMaxExpr(scUMaxExpr, Ops);
4378}
4379
4380const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4381 const SCEV *RHS) {
4382 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4383 return getSMinExpr(Ops);
4384}
4385
4386const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4387 return getMinMaxExpr(scSMinExpr, Ops);
4388}
4389
4390const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4391 bool Sequential) {
4392 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4393 return getUMinExpr(Ops, Sequential);
4394}
4395
4396const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4397 bool Sequential) {
4398 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4399 : getMinMaxExpr(scUMinExpr, Ops);
4400}
4401
4402const SCEV *
4403ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4404 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4405 if (Size.isScalable())
4406 Res = getMulExpr(Res, getVScale(IntTy));
4407 return Res;
4408}
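// For example, a scalable vector type such as <vscale x 4 x i32> has a known
// minimum size of 16 bytes, so its size expression is (16 * vscale).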
4409
4410const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4411 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4412}
4413
4414const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4415 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4416}
4417
4418const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4419 StructType *STy,
4420 unsigned FieldNo) {
4421 // We can bypass creating a target-independent constant expression and then
4422 // folding it back into a ConstantInt. This is just a compile-time
4423 // optimization.
4424 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4425 assert(!SL->getSizeInBits().isScalable() &&
4426 "Cannot get offset for structure containing scalable vector types");
4427 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4428}
4429
4430const SCEV *ScalarEvolution::getUnknown(Value *V) {
4431 // Don't attempt to do anything other than create a SCEVUnknown object
4432 // here. createSCEV only calls getUnknown after checking for all other
4433 // interesting possibilities, and any other code that calls getUnknown
4434 // is doing so in order to hide a value from SCEV canonicalization.
4435
4436 FoldingSetNodeID ID;
4437 ID.AddInteger(scUnknown);
4438 ID.AddPointer(V);
4439 void *IP = nullptr;
4440 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4441 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4442 "Stale SCEVUnknown in uniquing map!");
4443 return S;
4444 }
4445 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4446 FirstUnknown);
4447 FirstUnknown = cast<SCEVUnknown>(S);
4448 UniqueSCEVs.InsertNode(S, IP);
4449 return S;
4450}
4451
4452//===----------------------------------------------------------------------===//
4453// Basic SCEV Analysis and PHI Idiom Recognition Code
4454//
4455
4456/// Test if values of the given type are analyzable within the SCEV
4457/// framework. This primarily includes integer types, and it can optionally
4458/// include pointer types if the ScalarEvolution class has access to
4459/// target-specific information.
4460bool ScalarEvolution::isSCEVable(Type *Ty) const {
4461 // Integers and pointers are always SCEVable.
4462 return Ty->isIntOrPtrTy();
4463}
4464
4465/// Return the size in bits of the specified type, for which isSCEVable must
4466/// return true.
4467uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4468 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4469 if (Ty->isPointerTy())
4470 return getDataLayout().getIndexTypeSizeInBits(Ty);
4471 return getDataLayout().getTypeSizeInBits(Ty);
4472}
4473
4474/// Return a type with the same bitwidth as the given type and which represents
4475/// how SCEV will treat the given type, for which isSCEVable must return
4476/// true. For pointer types, this is the pointer index sized integer type.
4477Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4478 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4479
4480 if (Ty->isIntegerTy())
4481 return Ty;
4482
4483 // The only other supported type is pointer.
4484 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4485 return getDataLayout().getIndexType(Ty);
4486}
4487
4488Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4489 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4490}
4491
4492bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4493 const SCEV *B) {
4494 /// For a valid use point to exist, the defining scope of one operand
4495 /// must dominate the other.
4496 bool PreciseA, PreciseB;
4497 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4498 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4499 if (!PreciseA || !PreciseB)
4500 // Can't tell.
4501 return false;
4502 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4503 DT.dominates(ScopeB, ScopeA);
4504}
4505
4506const SCEV *ScalarEvolution::getCouldNotCompute() {
4507 return CouldNotCompute.get();
4508}
4509
4510bool ScalarEvolution::checkValidity(const SCEV *S) const {
4511 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4512 auto *SU = dyn_cast<SCEVUnknown>(S);
4513 return SU && SU->getValue() == nullptr;
4514 });
4515
4516 return !ContainsNulls;
4517}
4518
4519bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4520 HasRecMapType::iterator I = HasRecMap.find(S);
4521 if (I != HasRecMap.end())
4522 return I->second;
4523
4524 bool FoundAddRec =
4525 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4526 HasRecMap.insert({S, FoundAddRec});
4527 return FoundAddRec;
4528}
4529
4530/// Return the ValueOffsetPair set for \p S. \p S can be represented
4531/// by the value and offset from any ValueOffsetPair in the set.
4532ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4533 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4534 if (SI == ExprValueMap.end())
4535 return std::nullopt;
4536 return SI->second.getArrayRef();
4537}
4538
4539/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4540/// cannot be used separately. eraseValueFromMap should be used to remove
4541/// V from ValueExprMap and ExprValueMap at the same time.
4542void ScalarEvolution::eraseValueFromMap(Value *V) {
4543 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4544 if (I != ValueExprMap.end()) {
4545 auto EVIt = ExprValueMap.find(I->second);
4546 bool Removed = EVIt->second.remove(V);
4547 (void) Removed;
4548 assert(Removed && "Value not in ExprValueMap?");
4549 ValueExprMap.erase(I);
4550 }
4551}
4552
4553void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4554 // A recursive query may have already computed the SCEV. It should be
4555 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4556 // inferred nowrap flags.
4557 auto It = ValueExprMap.find_as(V);
4558 if (It == ValueExprMap.end()) {
4559 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4560 ExprValueMap[S].insert(V);
4561 }
4562}
4563
4564/// Return an existing SCEV if it exists, otherwise analyze the expression and
4565/// create a new one.
4566const SCEV *ScalarEvolution::getSCEV(Value *V) {
4567 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4568
4569 if (const SCEV *S = getExistingSCEV(V))
4570 return S;
4571 return createSCEVIter(V);
4572}
4573
4574const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4575 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4576
4577 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4578 if (I != ValueExprMap.end()) {
4579 const SCEV *S = I->second;
4580 assert(checkValidity(S) &&
4581 "existing SCEV has not been properly invalidated");
4582 return S;
4583 }
4584 return nullptr;
4585}
4586
4587/// Return a SCEV corresponding to -V = -1*V
4588const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4589 SCEV::NoWrapFlags Flags) {
4590 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4591 return getConstant(
4592 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4593
4594 Type *Ty = V->getType();
4595 Ty = getEffectiveSCEVType(Ty);
4596 return getMulExpr(V, getMinusOne(Ty), Flags);
4597}
4598
4599/// If Expr computes ~A, return A else return nullptr
4600static const SCEV *MatchNotExpr(const SCEV *Expr) {
4601 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4602 if (!Add || Add->getNumOperands() != 2 ||
4603 !Add->getOperand(0)->isAllOnesValue())
4604 return nullptr;
4605
4606 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4607 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4608 !AddRHS->getOperand(0)->isAllOnesValue())
4609 return nullptr;
4610
4611 return AddRHS->getOperand(1);
4612}
4613
4614/// Return a SCEV corresponding to ~V = -1-V
4615const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4616 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4617
4618 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4619 return getConstant(
4620 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4621
4622 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4623 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4624 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4625 SmallVector<const SCEV *, 2> MatchedOperands;
4626 for (const SCEV *Operand : MME->operands()) {
4627 const SCEV *Matched = MatchNotExpr(Operand);
4628 if (!Matched)
4629 return (const SCEV *)nullptr;
4630 MatchedOperands.push_back(Matched);
4631 }
4632 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4633 MatchedOperands);
4634 };
4635 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4636 return Replaced;
4637 }
4638
4639 Type *Ty = V->getType();
4640 Ty = getEffectiveSCEVType(Ty);
4641 return getMinusSCEV(getMinusOne(Ty), V);
4642}
4643
4644const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4645 assert(P->getType()->isPointerTy());
4646
4647 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4648 // The base of an AddRec is the first operand.
4649 SmallVector<const SCEV *> Ops{AddRec->operands()};
4650 Ops[0] = removePointerBase(Ops[0]);
4651 // Don't try to transfer nowrap flags for now. We could in some cases
4652 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4653 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4654 }
4655 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4656 // The base of an Add is the pointer operand.
4657 SmallVector<const SCEV *> Ops{Add->operands()};
4658 const SCEV **PtrOp = nullptr;
4659 for (const SCEV *&AddOp : Ops) {
4660 if (AddOp->getType()->isPointerTy()) {
4661 assert(!PtrOp && "Cannot have multiple pointer ops");
4662 PtrOp = &AddOp;
4663 }
4664 }
4665 *PtrOp = removePointerBase(*PtrOp);
4666 // Don't try to transfer nowrap flags for now. We could in some cases
4667 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4668 return getAddExpr(Ops);
4669 }
4670 // Any other expression must be a pointer base.
4671 return getZero(P->getType());
4672}
4673
4674const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4675 SCEV::NoWrapFlags Flags,
4676 unsigned Depth) {
4677 // Fast path: X - X --> 0.
4678 if (LHS == RHS)
4679 return getZero(LHS->getType());
4680
4681 // If we subtract two pointers with different pointer bases, bail.
4682 // Eventually, we're going to add an assertion to getMulExpr that we
4683 // can't multiply by a pointer.
4684 if (RHS->getType()->isPointerTy()) {
4685 if (!LHS->getType()->isPointerTy() ||
4686 getPointerBase(LHS) != getPointerBase(RHS))
4687 return getCouldNotCompute();
4688 LHS = removePointerBase(LHS);
4689 RHS = removePointerBase(RHS);
4690 }
4691
4692 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4693 // makes it so that we cannot make much use of NUW.
4694 auto AddFlags = SCEV::FlagAnyWrap;
4695 const bool RHSIsNotMinSigned =
4696 !getSignedRangeMin(RHS).isMinSignedValue();
4697 if (hasFlags(Flags, SCEV::FlagNSW)) {
4698 // Let M be the minimum representable signed value. Then (-1)*RHS
4699 // signed-wraps if and only if RHS is M. That can happen even for
4700 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4701 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4702 // (-1)*RHS, we need to prove that RHS != M.
4703 //
4704 // If LHS is non-negative and we know that LHS - RHS does not
4705 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4706 // either by proving that RHS > M or that LHS >= 0.
4707 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4708 AddFlags = SCEV::FlagNSW;
4709 }
4710 }
4711
4712 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4713 // RHS is NSW and LHS >= 0.
4714 //
4715 // The difficulty here is that the NSW flag may have been proven
4716 // relative to a loop that is to be found in a recurrence in LHS and
4717 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4718 // larger scope than intended.
4719 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4720
4721 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4722}
4723
4724const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4725 unsigned Depth) {
4726 Type *SrcTy = V->getType();
4727 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4728 "Cannot truncate or zero extend with non-integer arguments!");
4729 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4730 return V; // No conversion
4731 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4732 return getTruncateExpr(V, Ty, Depth);
4733 return getZeroExtendExpr(V, Ty, Depth);
4734}
4735
4736const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4737 unsigned Depth) {
4738 Type *SrcTy = V->getType();
4739 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4740 "Cannot truncate or zero extend with non-integer arguments!");
4741 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4742 return V; // No conversion
4743 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4744 return getTruncateExpr(V, Ty, Depth);
4745 return getSignExtendExpr(V, Ty, Depth);
4746}
4747
4748const SCEV *
4749ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4750 Type *SrcTy = V->getType();
4751 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4752 "Cannot noop or zero extend with non-integer arguments!");
4754 "getNoopOrZeroExtend cannot truncate!");
4755 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4756 return V; // No conversion
4757 return getZeroExtendExpr(V, Ty);
4758}
4759
4760const SCEV *
4761ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4762 Type *SrcTy = V->getType();
4763 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4764 "Cannot noop or sign extend with non-integer arguments!");
4766 "getNoopOrSignExtend cannot truncate!");
4767 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4768 return V; // No conversion
4769 return getSignExtendExpr(V, Ty);
4770}
4771
4772const SCEV *
4773ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4774 Type *SrcTy = V->getType();
4775 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4776 "Cannot noop or any extend with non-integer arguments!");
4778 "getNoopOrAnyExtend cannot truncate!");
4779 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4780 return V; // No conversion
4781 return getAnyExtendExpr(V, Ty);
4782}
4783
4784const SCEV *
4785ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4786 Type *SrcTy = V->getType();
4787 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4788 "Cannot truncate or noop with non-integer arguments!");
4790 "getTruncateOrNoop cannot extend!");
4791 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4792 return V; // No conversion
4793 return getTruncateExpr(V, Ty);
4794}
4795
4796const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4797 const SCEV *RHS) {
4798 const SCEV *PromotedLHS = LHS;
4799 const SCEV *PromotedRHS = RHS;
4800
4801 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4802 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4803 else
4804 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4805
4806 return getUMaxExpr(PromotedLHS, PromotedRHS);
4807}
4808
4809const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4810 const SCEV *RHS,
4811 bool Sequential) {
4812 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4813 return getUMinFromMismatchedTypes(Ops, Sequential);
4814}
4815
4816const SCEV *
4817ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4818 bool Sequential) {
4819 assert(!Ops.empty() && "At least one operand must be!");
4820 // Trivial case.
4821 if (Ops.size() == 1)
4822 return Ops[0];
4823
4824 // Find the max type first.
4825 Type *MaxType = nullptr;
4826 for (const auto *S : Ops)
4827 if (MaxType)
4828 MaxType = getWiderType(MaxType, S->getType());
4829 else
4830 MaxType = S->getType();
4831 assert(MaxType && "Failed to find maximum type!");
4832
4833 // Extend all ops to max type.
4834 SmallVector<const SCEV *, 2> PromotedOps;
4835 for (const auto *S : Ops)
4836 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4837
4838 // Generate umin.
4839 return getUMinExpr(PromotedOps, Sequential);
4840}
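// For example (illustrative operand types): given %a of type i32 and %b of
// type i64, the wider type i64 is chosen as MaxType, %a is zero-extended, and
// the result is (umin (zext i32 %a to i64), %b).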
4841
4842const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4843 // A pointer operand may evaluate to a nonpointer expression, such as null.
4844 if (!V->getType()->isPointerTy())
4845 return V;
4846
4847 while (true) {
4848 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4849 V = AddRec->getStart();
4850 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4851 const SCEV *PtrOp = nullptr;
4852 for (const SCEV *AddOp : Add->operands()) {
4853 if (AddOp->getType()->isPointerTy()) {
4854 assert(!PtrOp && "Cannot have multiple pointer ops");
4855 PtrOp = AddOp;
4856 }
4857 }
4858 assert(PtrOp && "Must have pointer op");
4859 V = PtrOp;
4860 } else // Not something we can look further into.
4861 return V;
4862 }
4863}
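// For example (illustrative): getPointerBase({(8 + %p),+,4}) first steps to
// the AddRec start (8 + %p), then to its single pointer operand, and finally
// returns %p, which is not something we can look further into.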
4864
4865/// Push users of the given Instruction onto the given Worklist.
4866static void PushDefUseChildren(Instruction *I,
4867 SmallVectorImpl<Instruction *> &Worklist,
4868 SmallPtrSetImpl<Instruction *> &Visited) {
4869 // Push the def-use children onto the Worklist stack.
4870 for (User *U : I->users()) {
4871 auto *UserInsn = cast<Instruction>(U);
4872 if (Visited.insert(UserInsn).second)
4873 Worklist.push_back(UserInsn);
4874 }
4875}
4876
4877namespace {
4878
4879/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4880/// use its start expression. If the loop is not L, then, if IgnoreOtherLoops
4881/// is true, use the AddRec itself; otherwise the rewrite cannot be done.
4882/// The rewrite also cannot be done if S contains a SCEVUnknown that is not
4883/// invariant in L.
4884class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4885public:
4886 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4887 bool IgnoreOtherLoops = true) {
4888 SCEVInitRewriter Rewriter(L, SE);
4889 const SCEV *Result = Rewriter.visit(S);
4890 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4891 return SE.getCouldNotCompute();
4892 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4893 ? SE.getCouldNotCompute()
4894 : Result;
4895 }
4896
4897 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4898 if (!SE.isLoopInvariant(Expr, L))
4899 SeenLoopVariantSCEVUnknown = true;
4900 return Expr;
4901 }
4902
4903 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4904 // Only re-write AddRecExprs for this loop.
4905 if (Expr->getLoop() == L)
4906 return Expr->getStart();
4907 SeenOtherLoops = true;
4908 return Expr;
4909 }
4910
4911 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4912
4913 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4914
4915private:
4916 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4917 : SCEVRewriteVisitor(SE), L(L) {}
4918
4919 const Loop *L;
4920 bool SeenLoopVariantSCEVUnknown = false;
4921 bool SeenOtherLoops = false;
4922};
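// For example (illustrative): SCEVInitRewriter::rewrite({%a,+,%b}<L>, L, SE)
// yields %a. An AddRec belonging to a different loop is left untouched when
// IgnoreOtherLoops is true; with IgnoreOtherLoops == false the whole rewrite
// instead returns SCEVCouldNotCompute.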
4923
4924/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4925/// use its post-increment expression. If the loop is not L, use the AddRec
4926/// itself.
4927/// The rewrite cannot be done if S contains a loop-variant SCEVUnknown.
4928class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4929public:
4930 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4931 SCEVPostIncRewriter Rewriter(L, SE);
4932 const SCEV *Result = Rewriter.visit(S);
4933 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4934 ? SE.getCouldNotCompute()
4935 : Result;
4936 }
4937
4938 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4939 if (!SE.isLoopInvariant(Expr, L))
4940 SeenLoopVariantSCEVUnknown = true;
4941 return Expr;
4942 }
4943
4944 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4945 // Only re-write AddRecExprs for this loop.
4946 if (Expr->getLoop() == L)
4947 return Expr->getPostIncExpr(SE);
4948 SeenOtherLoops = true;
4949 return Expr;
4950 }
4951
4952 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4953
4954 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4955
4956private:
4957 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4958 : SCEVRewriteVisitor(SE), L(L) {}
4959
4960 const Loop *L;
4961 bool SeenLoopVariantSCEVUnknown = false;
4962 bool SeenOtherLoops = false;
4963};
4964
4965/// This class evaluates the compare condition by matching it against the
4966/// condition of the loop latch. If there is a match, we assume a true value
4967/// for the condition while building SCEV nodes.
4968class SCEVBackedgeConditionFolder
4969 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4970public:
4971 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4972 ScalarEvolution &SE) {
4973 bool IsPosBECond = false;
4974 Value *BECond = nullptr;
4975 if (BasicBlock *Latch = L->getLoopLatch()) {
4976 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4977 if (BI && BI->isConditional()) {
4978 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4979 "Both outgoing branches should not target same header!");
4980 BECond = BI->getCondition();
4981 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4982 } else {
4983 return S;
4984 }
4985 }
4986 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4987 return Rewriter.visit(S);
4988 }
4989
4990 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4991 const SCEV *Result = Expr;
4992 bool InvariantF = SE.isLoopInvariant(Expr, L);
4993
4994 if (!InvariantF) {
4995 Instruction *I = cast<Instruction>(Expr->getValue());
4996 switch (I->getOpcode()) {
4997 case Instruction::Select: {
4998 SelectInst *SI = cast<SelectInst>(I);
4999 std::optional<const SCEV *> Res =
5000 compareWithBackedgeCondition(SI->getCondition());
5001 if (Res) {
5002 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5003 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5004 }
5005 break;
5006 }
5007 default: {
5008 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5009 if (Res)
5010 Result = *Res;
5011 break;
5012 }
5013 }
5014 }
5015 return Result;
5016 }
5017
5018private:
5019 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5020 bool IsPosBECond, ScalarEvolution &SE)
5021 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5022 IsPositiveBECond(IsPosBECond) {}
5023
5024 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5025
5026 const Loop *L;
5027 /// Loop back condition.
5028 Value *BackedgeCond = nullptr;
5029 /// Set to true if loop back is on positive branch condition.
5030 bool IsPositiveBECond;
5031};
5032
5033std::optional<const SCEV *>
5034SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5035
5036 // If value matches the backedge condition for loop latch,
5037 // then return a constant evolution node based on loopback
5038 // branch taken.
5039 if (BackedgeCond == IC)
5040 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5041 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5042 return std::nullopt;
5043}
5044
5045class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5046public:
5047 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5048 ScalarEvolution &SE) {
5049 SCEVShiftRewriter Rewriter(L, SE);
5050 const SCEV *Result = Rewriter.visit(S);
5051 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5052 }
5053
5054 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5055 // Only allow AddRecExprs for this loop.
5056 if (!SE.isLoopInvariant(Expr, L))
5057 Valid = false;
5058 return Expr;
5059 }
5060
5061 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5062 if (Expr->getLoop() == L && Expr->isAffine())
5063 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5064 Valid = false;
5065 return Expr;
5066 }
5067
5068 bool isValid() { return Valid; }
5069
5070private:
5071 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5072 : SCEVRewriteVisitor(SE), L(L) {}
5073
5074 const Loop *L;
5075 bool Valid = true;
5076};
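// For example (illustrative): SCEVShiftRewriter::rewrite({A,+,B}<L>, L, SE)
// produces {A,+,B} - B = {A-B,+,B}<L>, i.e. the same evolution shifted back by
// one iteration, which is what the PHI(f(0), f({1,+,1})) --> f({0,+,1})
// generalization in createAddRecFromPHI relies on.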
5077
5078} // end anonymous namespace
5079
5080SCEV::NoWrapFlags
5081ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5082 if (!AR->isAffine())
5083 return SCEV::FlagAnyWrap;
5084
5085 using OBO = OverflowingBinaryOperator;
5086
5087 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5088
5089 if (!AR->hasNoSelfWrap()) {
5090 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5091 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5092 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5093 const APInt &BECountAP = BECountMax->getAPInt();
5094 unsigned NoOverflowBitWidth =
5095 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5096 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5097 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5098 }
5099 }
5100
5101 if (!AR->hasNoSignedWrap()) {
5102 ConstantRange AddRecRange = getSignedRange(AR);
5103 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5104
5105 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5106 Instruction::Add, IncRange, OBO::NoSignedWrap);
5107 if (NSWRegion.contains(AddRecRange))
5108 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5109 }
5110
5111 if (!AR->hasNoUnsignedWrap()) {
5112 ConstantRange AddRecRange = getUnsignedRange(AR);
5113 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5114
5115 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5116 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5117 if (NUWRegion.contains(AddRecRange))
5118 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5119 }
5120
5121 return Result;
5122}
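// For example (an illustrative sketch of the range argument): if the unsigned
// range of the AddRec over the analyzed loop is [0, 100) and the step range is
// [1, 2), then adding the step can never leave the no-unsigned-wrap region
// computed above, so FlagNUW can be set without looking at the loop's guarding
// branches at all.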
5123
5124SCEV::NoWrapFlags
5125ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5126 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5127
5128 if (AR->hasNoSignedWrap())
5129 return Result;
5130
5131 if (!AR->isAffine())
5132 return Result;
5133
5134 // This function can be expensive, only try to prove NSW once per AddRec.
5135 if (!SignedWrapViaInductionTried.insert(AR).second)
5136 return Result;
5137
5138 const SCEV *Step = AR->getStepRecurrence(*this);
5139 const Loop *L = AR->getLoop();
5140
5141 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5142 // Note that this serves two purposes: It filters out loops that are
5143 // simply not analyzable, and it covers the case where this code is
5144 // being called from within backedge-taken count analysis, such that
5145 // attempting to ask for the backedge-taken count would likely result
5146 // in infinite recursion. In the latter case, the analysis code will
5147 // cope with a conservative value, and it will take care to purge
5148 // that value once it has finished.
5149 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5150
5151 // Normally, in the cases we can prove no-overflow via a
5152 // backedge guarding condition, we can also compute a backedge
5153 // taken count for the loop. The exceptions are assumptions and
5154 // guards present in the loop -- SCEV is not great at exploiting
5155 // these to compute max backedge taken counts, but can still use
5156 // these to prove lack of overflow. Use this fact to avoid
5157 // doing extra work that may not pay off.
5158
5159 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5160 AC.assumptions().empty())
5161 return Result;
5162
5163 // If the backedge is guarded by a comparison with the pre-inc value the
5164 // addrec is safe. Also, if the entry is guarded by a comparison with the
5165 // start value and the backedge is guarded by a comparison with the post-inc
5166 // value, the addrec is safe.
5167 ICmpInst::Predicate Pred;
5168 const SCEV *OverflowLimit =
5169 getSignedOverflowLimitForStep(Step, &Pred, this);
5170 if (OverflowLimit &&
5171 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5172 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5173 Result = setFlags(Result, SCEV::FlagNSW);
5174 }
5175 return Result;
5176}
5177SCEV::NoWrapFlags
5178ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5179 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5180
5181 if (AR->hasNoUnsignedWrap())
5182 return Result;
5183
5184 if (!AR->isAffine())
5185 return Result;
5186
5187 // This function can be expensive, only try to prove NUW once per AddRec.
5188 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5189 return Result;
5190
5191 const SCEV *Step = AR->getStepRecurrence(*this);
5192 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5193 const Loop *L = AR->getLoop();
5194
5195 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5196 // Note that this serves two purposes: It filters out loops that are
5197 // simply not analyzable, and it covers the case where this code is
5198 // being called from within backedge-taken count analysis, such that
5199 // attempting to ask for the backedge-taken count would likely result
5200 // in infinite recursion. In the latter case, the analysis code will
5201 // cope with a conservative value, and it will take care to purge
5202 // that value once it has finished.
5203 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5204
5205 // Normally, in the cases we can prove no-overflow via a
5206 // backedge guarding condition, we can also compute a backedge
5207 // taken count for the loop. The exceptions are assumptions and
5208 // guards present in the loop -- SCEV is not great at exploiting
5209 // these to compute max backedge taken counts, but can still use
5210 // these to prove lack of overflow. Use this fact to avoid
5211 // doing extra work that may not pay off.
5212
5213 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5214 AC.assumptions().empty())
5215 return Result;
5216
5217 // If the backedge is guarded by a comparison with the pre-inc value the
5218 // addrec is safe. Also, if the entry is guarded by a comparison with the
5219 // start value and the backedge is guarded by a comparison with the post-inc
5220 // value, the addrec is safe.
5221 if (isKnownPositive(Step)) {
5222 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5223 getUnsignedRangeMax(Step));
5224 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5225 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5226 Result = setFlags(Result, SCEV::FlagNUW);
5227 }
5228 }
5229
5230 return Result;
5231}
5232
5233namespace {
5234
5235/// Represents an abstract binary operation. This may exist as a
5236/// normal instruction or constant expression, or may have been
5237/// derived from an expression tree.
5238struct BinaryOp {
5239 unsigned Opcode;
5240 Value *LHS;
5241 Value *RHS;
5242 bool IsNSW = false;
5243 bool IsNUW = false;
5244
5245 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5246 /// constant expression.
5247 Operator *Op = nullptr;
5248
5249 explicit BinaryOp(Operator *Op)
5250 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5251 Op(Op) {
5252 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5253 IsNSW = OBO->hasNoSignedWrap();
5254 IsNUW = OBO->hasNoUnsignedWrap();
5255 }
5256 }
5257
5258 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5259 bool IsNUW = false)
5260 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5261};
5262
5263} // end anonymous namespace
5264
5265/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5266static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5267 AssumptionCache &AC,
5268 const DominatorTree &DT,
5269 const Instruction *CxtI) {
5270 auto *Op = dyn_cast<Operator>(V);
5271 if (!Op)
5272 return std::nullopt;
5273
5274 // Implementation detail: all the cleverness here should happen without
5275 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5276 // SCEV expressions when possible, and we should not break that.
5277
5278 switch (Op->getOpcode()) {
5279 case Instruction::Add:
5280 case Instruction::Sub:
5281 case Instruction::Mul:
5282 case Instruction::UDiv:
5283 case Instruction::URem:
5284 case Instruction::And:
5285 case Instruction::AShr:
5286 case Instruction::Shl:
5287 return BinaryOp(Op);
5288
5289 case Instruction::Or: {
5290 // Convert or disjoint into add nuw nsw.
5291 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5292 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5293 /*IsNSW=*/true, /*IsNUW=*/true);
5294 return BinaryOp(Op);
5295 }
5296
5297 case Instruction::Xor:
5298 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5299 // If the RHS of the xor is a signmask, then this is just an add.
5300 // Instcombine turns add of signmask into xor as a strength reduction step.
5301 if (RHSC->getValue().isSignMask())
5302 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5303 // Binary `xor` is a bit-wise `add`.
5304 if (V->getType()->isIntegerTy(1))
5305 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5306 return BinaryOp(Op);
5307
5308 case Instruction::LShr:
5309 // Turn logical shift right by a constant into an unsigned divide.
5310 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5311 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5312
5313 // If the shift count is not less than the bitwidth, the result of
5314 // the shift is undefined. Don't try to analyze it, because the
5315 // resolution chosen here may differ from the resolution chosen in
5316 // other parts of the compiler.
5317 if (SA->getValue().ult(BitWidth)) {
5318 Constant *X =
5319 ConstantInt::get(SA->getContext(),
5320 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5321 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5322 }
5323 }
5324 return BinaryOp(Op);
5325
5326 case Instruction::ExtractValue: {
5327 auto *EVI = cast<ExtractValueInst>(Op);
5328 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5329 break;
5330
5331 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5332 if (!WO)
5333 break;
5334
5335 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5336 bool Signed = WO->isSigned();
5337 // TODO: Should add nuw/nsw flags for mul as well.
5338 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5339 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5340
5341 // Now that we know that all uses of the arithmetic-result component of
5342 // CI are guarded by the overflow check, we can go ahead and pretend
5343 // that the arithmetic is non-overflowing.
5344 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5345 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5346 }
5347
5348 default:
5349 break;
5350 }
5351
5352 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5353 // semantics as a Sub, return a binary sub expression.
5354 if (auto *II = dyn_cast<IntrinsicInst>(V))
5355 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5356 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5357
5358 return std::nullopt;
5359}
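// A few illustrative mappings performed above (the value names are
// placeholders, not from any particular test):
//   %s = lshr i32 %x, 4          --> BinaryOp(UDiv, %x, 16)
//   %o = or disjoint i32 %a, %b  --> BinaryOp(Add, %a, %b, /*IsNSW=*/true, /*IsNUW=*/true)
//   %b1 = xor i1 %p, %q          --> BinaryOp(Add, %p, %q)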
5360
5361/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5362/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5363/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5364/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5365/// follows one of the following patterns:
5366/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5367/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5368/// If the SCEV expression of \p Op conforms with one of the expected patterns
5369/// we return the type of the truncation operation, and indicate whether the
5370/// truncated type should be treated as signed/unsigned by setting
5371/// \p Signed to true/false, respectively.
5372static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5373 bool &Signed, ScalarEvolution &SE) {
5374 // The case where Op == SymbolicPHI (that is, with no type conversions on
5375 // the way) is handled by the regular add recurrence creating logic and
5376 // would have already been triggered in createAddRecForPHI. Reaching it here
5377 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5378 // because one of the other operands of the SCEVAddExpr updating this PHI is
5379 // not invariant).
5380 //
5381 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5382 // this case predicates that allow us to prove that Op == SymbolicPHI will
5383 // be added.
5384 if (Op == SymbolicPHI)
5385 return nullptr;
5386
5387 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5388 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5389 if (SourceBits != NewBits)
5390 return nullptr;
5391
5392 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5393 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5394 if (!SExt && !ZExt)
5395 return nullptr;
5396 const SCEVTruncateExpr *Trunc =
5397 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5398 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5399 if (!Trunc)
5400 return nullptr;
5401 const SCEV *X = Trunc->getOperand();
5402 if (X != SymbolicPHI)
5403 return nullptr;
5404 Signed = SExt != nullptr;
5405 return Trunc->getType();
5406}
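// For example (illustrative): with %X an i64 loop-header phi, the operand
//   Op = (sext i32 (trunc i64 %X to i32) to i64)
// makes isSimpleCastedPHI return the truncated type i32 and set Signed to
// true; a zext in place of the sext returns i32 with Signed set to false.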
5407
5408static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5409 if (!PN->getType()->isIntegerTy())
5410 return nullptr;
5411 const Loop *L = LI.getLoopFor(PN->getParent());
5412 if (!L || L->getHeader() != PN->getParent())
5413 return nullptr;
5414 return L;
5415}
5416
5417// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5418// computation that updates the phi follows the following pattern:
5419// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5420// which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5421// If so, try to see if it can be rewritten as an AddRecExpr under some
5422// Predicates. If successful, return them as a pair. Also cache the results
5423// of the analysis.
5424//
5425// Example usage scenario:
5426// Say the Rewriter is called for the following SCEV:
5427// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5428// where:
5429// %X = phi i64 (%Start, %BEValue)
5430// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5431// and call this function with %SymbolicPHI = %X.
5432//
5433// The analysis will find that the value coming around the backedge has
5434// the following SCEV:
5435// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5436// Upon concluding that this matches the desired pattern, the function
5437// will return the pair {NewAddRec, SmallPredsVec} where:
5438// NewAddRec = {%Start,+,%Step}
5439// SmallPredsVec = {P1, P2, P3} as follows:
5440// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5441// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5442// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5443// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5444// under the predicates {P1,P2,P3}.
5445// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5446// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5447//
5448// TODO's:
5449//
5450// 1) Extend the Induction descriptor to also support inductions that involve
5451// casts: When needed (namely, when we are called in the context of the
5452// vectorizer induction analysis), a Set of cast instructions will be
5453// populated by this method, and provided back to isInductionPHI. This is
5454// needed to allow the vectorizer to properly record them to be ignored by
5455// the cost model and to avoid vectorizing them (otherwise these casts,
5456// which are redundant under the runtime overflow checks, will be
5457// vectorized, which can be costly).
5458//
5459// 2) Support additional induction/PHISCEV patterns: We also want to support
5460// inductions where the sext-trunc / zext-trunc operations (partly) occur
5461// after the induction update operation (the induction increment):
5462//
5463// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5464// which corresponds to a phi->add->trunc->sext/zext->phi update chain.
5465//
5466// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5467// which corresponds to a phi->trunc->add->sext/zext->phi update chain.
5468//
5469// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5470std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5471ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5472 SmallVector<const SCEVPredicate *, 3> Predicates;
5473
5474 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5475 // return an AddRec expression under some predicate.
5476
5477 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5478 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5479 assert(L && "Expecting an integer loop header phi");
5480
5481 // The loop may have multiple entrances or multiple exits; we can analyze
5482 // this phi as an addrec if it has a unique entry value and a unique
5483 // backedge value.
5484 Value *BEValueV = nullptr, *StartValueV = nullptr;
5485 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5486 Value *V = PN->getIncomingValue(i);
5487 if (L->contains(PN->getIncomingBlock(i))) {
5488 if (!BEValueV) {
5489 BEValueV = V;
5490 } else if (BEValueV != V) {
5491 BEValueV = nullptr;
5492 break;
5493 }
5494 } else if (!StartValueV) {
5495 StartValueV = V;
5496 } else if (StartValueV != V) {
5497 StartValueV = nullptr;
5498 break;
5499 }
5500 }
5501 if (!BEValueV || !StartValueV)
5502 return std::nullopt;
5503
5504 const SCEV *BEValue = getSCEV(BEValueV);
5505
5506 // If the value coming around the backedge is an add with the symbolic
5507 // value we just inserted, possibly with casts that we can ignore under
5508 // an appropriate runtime guard, then we found a simple induction variable!
5509 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5510 if (!Add)
5511 return std::nullopt;
5512
5513 // If there is a single occurrence of the symbolic value, possibly
5514 // casted, replace it with a recurrence.
5515 unsigned FoundIndex = Add->getNumOperands();
5516 Type *TruncTy = nullptr;
5517 bool Signed;
5518 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5519 if ((TruncTy =
5520 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5521 if (FoundIndex == e) {
5522 FoundIndex = i;
5523 break;
5524 }
5525
5526 if (FoundIndex == Add->getNumOperands())
5527 return std::nullopt;
5528
5529 // Create an add with everything but the specified operand.
5530 SmallVector<const SCEV *, 8> Ops;
5531 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5532 if (i != FoundIndex)
5533 Ops.push_back(Add->getOperand(i));
5534 const SCEV *Accum = getAddExpr(Ops);
5535
5536 // The runtime checks will not be valid if the step amount is
5537 // varying inside the loop.
5538 if (!isLoopInvariant(Accum, L))
5539 return std::nullopt;
5540
5541 // *** Part2: Create the predicates
5542
5543 // Analysis was successful: we have a phi-with-cast pattern for which we
5544 // can return an AddRec expression under the following predicates:
5545 //
5546 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5547 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5548 // P2: An Equal predicate that guarantees that
5549 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5550 // P3: An Equal predicate that guarantees that
5551 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5552 //
5553 // As we next prove, the above predicates guarantee that:
5554 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5555 //
5556 //
5557 // More formally, we want to prove that:
5558 // Expr(i+1) = Start + (i+1) * Accum
5559 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5560 //
5561 // Given that:
5562 // 1) Expr(0) = Start
5563 // 2) Expr(1) = Start + Accum
5564 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5565 // 3) Induction hypothesis (step i):
5566 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5567 //
5568 // Proof:
5569 // Expr(i+1) =
5570 // = Start + (i+1)*Accum
5571 // = (Start + i*Accum) + Accum
5572 // = Expr(i) + Accum
5573 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5574 // :: from step i
5575 //
5576 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5577 //
5578 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5579 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5580 // + Accum :: from P3
5581 //
5582 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5583 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5584 //
5585 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5586 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5587 //
5588 // By induction, the same applies to all iterations 1<=i<n:
5589 //
5590
5591 // Create a truncated addrec for which we will add a no overflow check (P1).
5592 const SCEV *StartVal = getSCEV(StartValueV);
5593 const SCEV *PHISCEV =
5594 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5595 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5596
5597 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5598 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5599 // will be constant.
5600 //
5601 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5602 // add P1.
5603 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5604 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5605 Signed ? SCEVWrapPredicate::IncrementNSSW
5606 : SCEVWrapPredicate::IncrementNUSW;
5607 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5608 Predicates.push_back(AddRecPred);
5609 }
5610
5611 // Create the Equal Predicates P2,P3:
5612
5613 // It is possible that the predicates P2 and/or P3 are computable at
5614 // compile time due to StartVal and/or Accum being constants.
5615 // If either one is, then we can check that now and escape if either P2
5616 // or P3 is false.
5617
5618 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5619 // for each of StartVal and Accum
5620 auto getExtendedExpr = [&](const SCEV *Expr,
5621 bool CreateSignExtend) -> const SCEV * {
5622 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5623 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5624 const SCEV *ExtendedExpr =
5625 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5626 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5627 return ExtendedExpr;
5628 };
5629
5630 // Given:
5631 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5632 // = getExtendedExpr(Expr)
5633 // Determine whether the predicate P: Expr == ExtendedExpr
5634 // is known to be false at compile time
5635 auto PredIsKnownFalse = [&](const SCEV *Expr,
5636 const SCEV *ExtendedExpr) -> bool {
5637 return Expr != ExtendedExpr &&
5638 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5639 };
5640
5641 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5642 if (PredIsKnownFalse(StartVal, StartExtended)) {
5643 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5644 return std::nullopt;
5645 }
5646
5647 // The Step is always Signed (because the overflow checks are either
5648 // NSSW or NUSW)
5649 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5650 if (PredIsKnownFalse(Accum, AccumExtended)) {
5651 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5652 return std::nullopt;
5653 }
5654
5655 auto AppendPredicate = [&](const SCEV *Expr,
5656 const SCEV *ExtendedExpr) -> void {
5657 if (Expr != ExtendedExpr &&
5658 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5659 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5660 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5661 Predicates.push_back(Pred);
5662 }
5663 };
5664
5665 AppendPredicate(StartVal, StartExtended);
5666 AppendPredicate(Accum, AccumExtended);
5667
5668 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5669 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5670 // into NewAR if it will also add the runtime overflow checks specified in
5671 // Predicates.
5672 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5673
5674 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5675 std::make_pair(NewAR, Predicates);
5676 // Remember the result of the analysis for this SCEV at this location.
5677 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5678 return PredRewrite;
5679}
5680
5681std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5682ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5683 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5684 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5685 if (!L)
5686 return std::nullopt;
5687
5688 // Check to see if we already analyzed this PHI.
5689 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5690 if (I != PredicatedSCEVRewrites.end()) {
5691 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5692 I->second;
5693 // Analysis was done before and failed to create an AddRec:
5694 if (Rewrite.first == SymbolicPHI)
5695 return std::nullopt;
5696 // Analysis was done before and succeeded to create an AddRec under
5697 // a predicate:
5698 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5699 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5700 return Rewrite;
5701 }
5702
5703 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5704 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5705
5706 // Record in the cache that the analysis failed
5707 if (!Rewrite) {
5708 SmallVector<const SCEVPredicate *, 3> Predicates;
5709 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5710 return std::nullopt;
5711 }
5712
5713 return Rewrite;
5714}
5715
5716// FIXME: This utility is currently required because the Rewriter currently
5717// does not rewrite this expression:
5718// {0, +, (sext ix (trunc iy to ix) to iy)}
5719// into {0, +, %step},
5720// even when the following Equal predicate exists:
5721// "%step == (sext ix (trunc iy to ix) to iy)".
5722bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5723 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5724 if (AR1 == AR2)
5725 return true;
5726
5727 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5728 if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5729 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5730 return false;
5731 return true;
5732 };
5733
5734 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5735 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5736 return false;
5737 return true;
5738}
5739
5740/// A helper function for createAddRecFromPHI to handle simple cases.
5741///
5742/// This function tries to find an AddRec expression for the simplest (yet most
5743/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5744/// If it fails, createAddRecFromPHI will use a more general, but slow,
5745/// technique for finding the AddRec expression.
5746const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5747 Value *BEValueV,
5748 Value *StartValueV) {
5749 const Loop *L = LI.getLoopFor(PN->getParent());
5750 assert(L && L->getHeader() == PN->getParent());
5751 assert(BEValueV && StartValueV);
5752
5753 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5754 if (!BO)
5755 return nullptr;
5756
5757 if (BO->Opcode != Instruction::Add)
5758 return nullptr;
5759
5760 const SCEV *Accum = nullptr;
5761 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5762 Accum = getSCEV(BO->RHS);
5763 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5764 Accum = getSCEV(BO->LHS);
5765
5766 if (!Accum)
5767 return nullptr;
5768
5769 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5770 if (BO->IsNUW)
5771 Flags = setFlags(Flags, SCEV::FlagNUW);
5772 if (BO->IsNSW)
5773 Flags = setFlags(Flags, SCEV::FlagNSW);
5774
5775 const SCEV *StartVal = getSCEV(StartValueV);
5776 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5777 insertValueToMap(PN, PHISCEV);
5778
5779 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5780 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5781 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5782 proveNoWrapViaConstantRanges(AR)));
5783 }
5784
5785 // We can add Flags to the post-inc expression only if we
5786 // know that it is *undefined behavior* for BEValueV to
5787 // overflow.
5788 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5789 assert(isLoopInvariant(Accum, L) &&
5790 "Accum is defined outside L, but is not invariant?");
5791 if (isAddRecNeverPoison(BEInst, L))
5792 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5793 }
5794
5795 return PHISCEV;
5796}
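// For example (illustrative IR, the names are placeholders):
//   loop:
//     %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
//     %iv.next = add nuw nsw i64 %iv, 1
// is matched here with PN = %iv, BEValueV = %iv.next and StartValueV = 0,
// yielding roughly the AddRec {0,+,1}<nuw><nsw> for %loop.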
5797
5798const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5799 const Loop *L = LI.getLoopFor(PN->getParent());
5800 if (!L || L->getHeader() != PN->getParent())
5801 return nullptr;
5802
5803 // The loop may have multiple entrances or multiple exits; we can analyze
5804 // this phi as an addrec if it has a unique entry value and a unique
5805 // backedge value.
5806 Value *BEValueV = nullptr, *StartValueV = nullptr;
5807 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5808 Value *V = PN->getIncomingValue(i);
5809 if (L->contains(PN->getIncomingBlock(i))) {
5810 if (!BEValueV) {
5811 BEValueV = V;
5812 } else if (BEValueV != V) {
5813 BEValueV = nullptr;
5814 break;
5815 }
5816 } else if (!StartValueV) {
5817 StartValueV = V;
5818 } else if (StartValueV != V) {
5819 StartValueV = nullptr;
5820 break;
5821 }
5822 }
5823 if (!BEValueV || !StartValueV)
5824 return nullptr;
5825
5826 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5827 "PHI node already processed?");
5828
5829 // First, try to find an AddRec expression without creating a fictitious symbolic
5830 // value for PN.
5831 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5832 return S;
5833
5834 // Handle PHI node value symbolically.
5835 const SCEV *SymbolicName = getUnknown(PN);
5836 insertValueToMap(PN, SymbolicName);
5837
5838 // Using this symbolic name for the PHI, analyze the value coming around
5839 // the back-edge.
5840 const SCEV *BEValue = getSCEV(BEValueV);
5841
5842 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5843 // has a special value for the first iteration of the loop.
5844
5845 // If the value coming around the backedge is an add with the symbolic
5846 // value we just inserted, then we found a simple induction variable!
5847 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5848 // If there is a single occurrence of the symbolic value, replace it
5849 // with a recurrence.
5850 unsigned FoundIndex = Add->getNumOperands();
5851 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5852 if (Add->getOperand(i) == SymbolicName)
5853 if (FoundIndex == e) {
5854 FoundIndex = i;
5855 break;
5856 }
5857
5858 if (FoundIndex != Add->getNumOperands()) {
5859 // Create an add with everything but the specified operand.
5860 SmallVector<const SCEV *, 8> Ops;
5861 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5862 if (i != FoundIndex)
5863 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5864 L, *this));
5865 const SCEV *Accum = getAddExpr(Ops);
5866
5867 // This is not a valid addrec if the step amount is varying each
5868 // loop iteration, but is not itself an addrec in this loop.
5869 if (isLoopInvariant(Accum, L) ||
5870 (isa<SCEVAddRecExpr>(Accum) &&
5871 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5872 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5873
5874 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5875 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5876 if (BO->IsNUW)
5877 Flags = setFlags(Flags, SCEV::FlagNUW);
5878 if (BO->IsNSW)
5879 Flags = setFlags(Flags, SCEV::FlagNSW);
5880 }
5881 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5882 // If the increment is an inbounds GEP, then we know the address
5883 // space cannot be wrapped around. We cannot make any guarantee
5884 // about signed or unsigned overflow because pointers are
5885 // unsigned but we may have a negative index from the base
5886 // pointer. We can guarantee that no unsigned wrap occurs if the
5887 // indices form a positive value.
5888 if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
5889 Flags = setFlags(Flags, SCEV::FlagNW);
5890 if (isKnownPositive(Accum))
5891 Flags = setFlags(Flags, SCEV::FlagNUW);
5892 }
5893
5894 // We cannot transfer nuw and nsw flags from subtraction
5895 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5896 // for instance.
5897 }
5898
5899 const SCEV *StartVal = getSCEV(StartValueV);
5900 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5901
5902 // Okay, for the entire analysis of this edge we assumed the PHI
5903 // to be symbolic. We now need to go back and purge all of the
5904 // entries for the scalars that use the symbolic expression.
5905 forgetMemoizedResults(SymbolicName);
5906 insertValueToMap(PN, PHISCEV);
5907
5908 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5909 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5910 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5911 proveNoWrapViaConstantRanges(AR)));
5912 }
5913
5914 // We can add Flags to the post-inc expression only if we
5915 // know that it is *undefined behavior* for BEValueV to
5916 // overflow.
5917 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5918 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5919 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5920
5921 return PHISCEV;
5922 }
5923 }
5924 } else {
5925 // Otherwise, this could be a loop like this:
5926 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5927 // In this case, j = {1,+,1} and BEValue is j.
5928 // Because the other in-value of i (0) fits the evolution of BEValue
5929 // i really is an addrec evolution.
5930 //
5931 // We can generalize this saying that i is the shifted value of BEValue
5932 // by one iteration:
5933 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5934 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5935 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5936 if (Shifted != getCouldNotCompute() &&
5937 Start != getCouldNotCompute()) {
5938 const SCEV *StartVal = getSCEV(StartValueV);
5939 if (Start == StartVal) {
5940 // Okay, for the entire analysis of this edge we assumed the PHI
5941 // to be symbolic. We now need to go back and purge all of the
5942 // entries for the scalars that use the symbolic expression.
5943 forgetMemoizedResults(SymbolicName);
5944 insertValueToMap(PN, Shifted);
5945 return Shifted;
5946 }
5947 }
5948 }
5949
5950 // Remove the temporary PHI node SCEV that has been inserted while intending
5951 // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
5952 // as it would prevent later (possibly simpler) SCEV expressions from being
5953 // added to the ValueExprMap.
5954 eraseValueFromMap(PN);
5955
5956 return nullptr;
5957}
5958
5959// Try to match a control flow sequence that branches out at BI and merges back
5960// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5961// match.
5962static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5963 Value *&C, Value *&LHS, Value *&RHS) {
5964 C = BI->getCondition();
5965
5966 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5967 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5968
5969 if (!LeftEdge.isSingleEdge())
5970 return false;
5971
5972 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5973
5974 Use &LeftUse = Merge->getOperandUse(0);
5975 Use &RightUse = Merge->getOperandUse(1);
5976
5977 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5978 LHS = LeftUse;
5979 RHS = RightUse;
5980 return true;
5981 }
5982
5983 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5984 LHS = RightUse;
5985 RHS = LeftUse;
5986 return true;
5987 }
5988
5989 return false;
5990}
5991
5992const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5993 auto IsReachable =
5994 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5995 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5996 // Try to match
5997 //
5998 // br %cond, label %left, label %right
5999 // left:
6000 // br label %merge
6001 // right:
6002 // br label %merge
6003 // merge:
6004 // V = phi [ %x, %left ], [ %y, %right ]
6005 //
6006 // as "select %cond, %x, %y"
6007
6008 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6009 assert(IDom && "At least the entry block should dominate PN");
6010
6011 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6012 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6013
6014 if (BI && BI->isConditional() &&
6015 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6016 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6017 properlyDominates(getSCEV(RHS), PN->getParent()))
6018 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6019 }
6020
6021 return nullptr;
6022}
6023
6024const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6025 if (const SCEV *S = createAddRecFromPHI(PN))
6026 return S;
6027
6028 if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
6029 return getSCEV(V);
6030
6031 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6032 return S;
6033
6034 // If it's not a loop phi, we can't handle it yet.
6035 return getUnknown(PN);
6036}
6037
6038bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6039 SCEVTypes RootKind) {
6040 struct FindClosure {
6041 const SCEV *OperandToFind;
6042 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6043 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6044
6045 bool Found = false;
6046
6047 bool canRecurseInto(SCEVTypes Kind) const {
6048 // We can only recurse into the SCEV expression of the same effective type
6049 // as the type of our root SCEV expression, and into zero-extensions.
6050 return RootKind == Kind || NonSequentialRootKind == Kind ||
6051 scZeroExtend == Kind;
6052 };
6053
6054 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6055 : OperandToFind(OperandToFind), RootKind(RootKind),
6056 NonSequentialRootKind(
6057 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6058 RootKind)) {}
6059
6060 bool follow(const SCEV *S) {
6061 Found = S == OperandToFind;
6062
6063 return !isDone() && canRecurseInto(S->getSCEVType());
6064 }
6065
6066 bool isDone() const { return Found; }
6067 };
6068
6069 FindClosure FC(OperandToFind, RootKind);
6070 visitAll(Root, FC);
6071 return FC.Found;
6072}
6073
6074std::optional<const SCEV *>
6075ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6076 ICmpInst *Cond,
6077 Value *TrueVal,
6078 Value *FalseVal) {
6079 // Try to match some simple smax or umax patterns.
6080 auto *ICI = Cond;
6081
6082 Value *LHS = ICI->getOperand(0);
6083 Value *RHS = ICI->getOperand(1);
6084
6085 switch (ICI->getPredicate()) {
6086 case ICmpInst::ICMP_SLT:
6087 case ICmpInst::ICMP_SLE:
6088 case ICmpInst::ICMP_ULT:
6089 case ICmpInst::ICMP_ULE:
6090 std::swap(LHS, RHS);
6091 [[fallthrough]];
6092 case ICmpInst::ICMP_SGT:
6093 case ICmpInst::ICMP_SGE:
6094 case ICmpInst::ICMP_UGT:
6095 case ICmpInst::ICMP_UGE:
6096 // a > b ? a+x : b+x -> max(a, b)+x
6097 // a > b ? b+x : a+x -> min(a, b)+x
6098 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6099 bool Signed = ICI->isSigned();
6100 const SCEV *LA = getSCEV(TrueVal);
6101 const SCEV *RA = getSCEV(FalseVal);
6102 const SCEV *LS = getSCEV(LHS);
6103 const SCEV *RS = getSCEV(RHS);
6104 if (LA->getType()->isPointerTy()) {
6105 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6106 // Need to make sure we can't produce weird expressions involving
6107 // negated pointers.
6108 if (LA == LS && RA == RS)
6109 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6110 if (LA == RS && RA == LS)
6111 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6112 }
6113 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6114 if (Op->getType()->isPointerTy()) {
6115 Op = getLosslessPtrToIntExpr(Op);
6116 if (isa<SCEVCouldNotCompute>(Op))
6117 return Op;
6118 }
6119 if (Signed)
6120 Op = getNoopOrSignExtend(Op, Ty);
6121 else
6122 Op = getNoopOrZeroExtend(Op, Ty);
6123 return Op;
6124 };
6125 LS = CoerceOperand(LS);
6126 RS = CoerceOperand(RS);
6127 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6128 break;
6129 const SCEV *LDiff = getMinusSCEV(LA, LS);
6130 const SCEV *RDiff = getMinusSCEV(RA, RS);
6131 if (LDiff == RDiff)
6132 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6133 LDiff);
6134 LDiff = getMinusSCEV(LA, RS);
6135 RDiff = getMinusSCEV(RA, LS);
6136 if (LDiff == RDiff)
6137 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6138 LDiff);
6139 }
6140 break;
6141 case ICmpInst::ICMP_NE:
6142 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6143 std::swap(TrueVal, FalseVal);
6144 [[fallthrough]];
6145 case ICmpInst::ICMP_EQ:
6146 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6147 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6148 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6149 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6150 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6151 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6152 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6153 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6154 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6155 return getAddExpr(getUMaxExpr(X, C), Y);
6156 }
6157 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6158 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6159 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6160 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6161 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6162 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6163 const SCEV *X = getSCEV(LHS);
6164 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6165 X = ZExt->getOperand();
6166 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6167 const SCEV *FalseValExpr = getSCEV(FalseVal);
6168 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6169 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6170 /*Sequential=*/true);
6171 }
6172 }
6173 break;
6174 default:
6175 break;
6176 }
6177
6178 return std::nullopt;
6179}
6180
6181static std::optional<const SCEV *>
6182createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6183 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6184 assert(CondExpr->getType()->isIntegerTy(1) &&
6185 TrueExpr->getType() == FalseExpr->getType() &&
6186 TrueExpr->getType()->isIntegerTy(1) &&
6187 "Unexpected operands of a select.");
6188
6189 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6190 // --> C + (umin_seq cond, x - C)
6191 //
6192 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6193 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6194 // --> C + (umin_seq ~cond, x - C)
6195
6196 // FIXME: while we can't legally model the case where both of the hands
6197 // are fully variable, we only require that the *difference* is constant.
6198 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6199 return std::nullopt;
6200
6201 const SCEV *X, *C;
6202 if (isa<SCEVConstant>(TrueExpr)) {
6203 CondExpr = SE->getNotSCEV(CondExpr);
6204 X = FalseExpr;
6205 C = TrueExpr;
6206 } else {
6207 X = TrueExpr;
6208 C = FalseExpr;
6209 }
6210 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6211 /*Sequential=*/true));
6212}
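// For example (illustrative): for `select i1 %c, i1 %x, i1 false` the false
// hand is the constant C = 0, so the expression above becomes
// 0 + (umin_seq %c, (%x - 0)) = (%c umin_seq %x), matching the short-circuit
// semantics of a logical "and": %x only matters when %c is 1.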
6213
6214static std::optional<const SCEV *>
6215createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6216 Value *FalseVal) {
6217 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6218 return std::nullopt;
6219
6220 const auto *SECond = SE->getSCEV(Cond);
6221 const auto *SETrue = SE->getSCEV(TrueVal);
6222 const auto *SEFalse = SE->getSCEV(FalseVal);
6223 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6224}
6225
6226const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6227 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6228 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6229 assert(TrueVal->getType() == FalseVal->getType() &&
6230 V->getType() == TrueVal->getType() &&
6231 "Types of select hands and of the result must match.");
6232
6233 // For now, only deal with i1-typed `select`s.
6234 if (!V->getType()->isIntegerTy(1))
6235 return getUnknown(V);
6236
6237 if (std::optional<const SCEV *> S =
6238 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6239 return *S;
6240
6241 return getUnknown(V);
6242}
6243
6244const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6245 Value *TrueVal,
6246 Value *FalseVal) {
6247 // Handle "constant" branch or select. This can occur for instance when a
6248 // loop pass transforms an inner loop and moves on to process the outer loop.
6249 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6250 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6251
6252 if (auto *I = dyn_cast<Instruction>(V)) {
6253 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6254 if (std::optional<const SCEV *> S =
6255 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6256 TrueVal, FalseVal))
6257 return *S;
6258 }
6259 }
6260
6261 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6262}
6263
6264/// Expand GEP instructions into add and multiply operations. This allows them
6265/// to be analyzed by regular SCEV code.
6266const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6267 assert(GEP->getSourceElementType()->isSized() &&
6268 "GEP source element type must be sized");
6269
6270 SmallVector<const SCEV *, 4> IndexExprs;
6271 for (Value *Index : GEP->indices())
6272 IndexExprs.push_back(getSCEV(Index));
6273 return getGEPExpr(GEP, IndexExprs);
6274}
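// For example (an illustrative sketch, assuming 64-bit pointers and a 4-byte
// i32): `getelementptr inbounds i32, ptr %p, i64 %i` is expanded by getGEPExpr
// into roughly (%p + 4 * %i), which the rest of SCEV can then analyze like any
// other add/mul expression.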
6275
6276APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6277 uint32_t BitWidth = getTypeSizeInBits(S->getType());
6278 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6279 return TrailingZeros >= BitWidth
6280 ? APInt::getZero(BitWidth)
6281 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6282 };
6283 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6284 // The result is GCD of all operands results.
6285 APInt Res = getConstantMultiple(N->getOperand(0));
6286 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6287 Res = APIntOps::GreatestCommonDivisor(
6288 Res, getConstantMultiple(N->getOperand(I)));
6289 return Res;
6290 };
6291
6292 switch (S->getSCEVType()) {
6293 case scConstant:
6294 return cast<SCEVConstant>(S)->getAPInt();
6295 case scPtrToInt:
6296 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6297 case scUDivExpr:
6298 case scVScale:
6299 return APInt(BitWidth, 1);
6300 case scTruncate: {
6301 // Only multiples that are a power of 2 will hold after truncation.
6302 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6303 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6304 return GetShiftedByZeros(TZ);
6305 }
6306 case scZeroExtend: {
6307 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6308 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6309 }
6310 case scSignExtend: {
6311 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6312 return getConstantMultiple(E->getOperand()).sext(BitWidth);
6313 }
6314 case scMulExpr: {
6315 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6316 if (M->hasNoUnsignedWrap()) {
6317 // The result is the product of all operand results.
6318 APInt Res = getConstantMultiple(M->getOperand(0));
6319 for (const SCEV *Operand : M->operands().drop_front())
6320 Res = Res * getConstantMultiple(Operand);
6321 return Res;
6322 }
6323
6324 // If there are no wrap guarantees, find the trailing zeros, which is the
6325 // sum of trailing zeros for all its operands.
6326 uint32_t TZ = 0;
6327 for (const SCEV *Operand : M->operands())
6328 TZ += getMinTrailingZeros(Operand);
6329 return GetShiftedByZeros(TZ);
6330 }
6331 case scAddExpr:
6332 case scAddRecExpr: {
6333 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6334 if (N->hasNoUnsignedWrap())
6335 return GetGCDMultiple(N);
6336 // Find the trailing zeros, which is the minimum over all operands.
6337 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6338 for (const SCEV *Operand : N->operands().drop_front())
6339 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6340 return GetShiftedByZeros(TZ);
6341 }
6342 case scUMaxExpr:
6343 case scSMaxExpr:
6344 case scUMinExpr:
6345 case scSMinExpr:
6346 case scSequentialUMinExpr:
6347 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6348 case scUnknown: {
6349 // ask ValueTracking for known bits
6350 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6351 unsigned Known =
6352 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6353 .countMinTrailingZeros();
6354 return GetShiftedByZeros(Known);
6355 }
6356 case scCouldNotCompute:
6357 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6358 }
6359 llvm_unreachable("Unknown SCEV kind!");
6360}
6361
6362 APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6363 auto I = ConstantMultipleCache.find(S);
6364 if (I != ConstantMultipleCache.end())
6365 return I->second;
6366
6367 APInt Result = getConstantMultipleImpl(S);
6368 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6369 assert(InsertPair.second && "Should insert a new key");
6370 return InsertPair.first->second;
6371}
6372
6373 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6374 APInt Multiple = getConstantMultiple(S);
6375 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6376}
6377
6378 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6379 return std::min(getConstantMultiple(S).countTrailingZeros(),
6380 (unsigned)getTypeSizeInBits(S->getType()));
6381}
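// Editor's illustration (assumed example): if getConstantMultiple(S) is 8
// (0b1000), the expression is always a multiple of 8 and getMinTrailingZeros
// returns 3. The std::min caps the answer at the type's bit width, which
// matters when the constant multiple is 0 and every bit counts as a trailing
// zero.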
6382
6383/// Helper method to assign a range to V from metadata present in the IR.
6384static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6385 if (Instruction *I = dyn_cast<Instruction>(V)) {
6386 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6387 return getConstantRangeFromMetadata(*MD);
6388 if (const auto *CB = dyn_cast<CallBase>(V))
6389 if (std::optional<ConstantRange> Range = CB->getRange())
6390 return Range;
6391 }
6392 if (auto *A = dyn_cast<Argument>(V))
6393 if (std::optional<ConstantRange> Range = A->getRange())
6394 return Range;
6395
6396 return std::nullopt;
6397}
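// Editor's illustration (made-up IR, not from the source): a load annotated as
//   %v = load i32, ptr %p, !range !{i32 0, i32 10}
// yields the half-open ConstantRange [0, 10) here; call-site and argument
// range attributes are picked up the same way through getRange().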
6398
6399 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6400 SCEV::NoWrapFlags Flags) {
6401 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6402 AddRec->setNoWrapFlags(Flags);
6403 UnsignedRanges.erase(AddRec);
6404 SignedRanges.erase(AddRec);
6405 ConstantMultipleCache.erase(AddRec);
6406 }
6407}
6408
6409ConstantRange ScalarEvolution::
6410getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6411 const DataLayout &DL = getDataLayout();
6412
6413 unsigned BitWidth = getTypeSizeInBits(U->getType());
6414 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6415
6416 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6417 // use information about the trip count to improve our available range. Note
6418 // that the trip count independent cases are already handled by known bits.
6419 // WARNING: The definition of recurrence used here is subtly different than
6420 // the one used by AddRec (and thus most of this file). Step is allowed to
6421 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6422 // and other addrecs in the same loop (for non-affine addrecs). The code
6423 // below intentionally handles the case where step is not loop invariant.
6424 auto *P = dyn_cast<PHINode>(U->getValue());
6425 if (!P)
6426 return FullSet;
6427
6428 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6429 // even the values that are not available in these blocks may come from them,
6430 // and this leads to a false-positive recurrence test.
6431 for (auto *Pred : predecessors(P->getParent()))
6432 if (!DT.isReachableFromEntry(Pred))
6433 return FullSet;
6434
6435 BinaryOperator *BO;
6436 Value *Start, *Step;
6437 if (!matchSimpleRecurrence(P, BO, Start, Step))
6438 return FullSet;
6439
6440 // If we found a recurrence in reachable code, we must be in a loop. Note
6441 // that BO might be in some subloop of L, and that's completely okay.
6442 auto *L = LI.getLoopFor(P->getParent());
6443 assert(L && L->getHeader() == P->getParent());
6444 if (!L->contains(BO->getParent()))
6445 // NOTE: This bailout should be an assert instead. However, asserting
6446 // the condition here exposes a case where LoopFusion is querying SCEV
6447 // with malformed loop information in the midst of the transform.
6448 // There doesn't appear to be an obvious fix, so for the moment bailout
6449 // until the caller issue can be fixed. PR49566 tracks the bug.
6450 return FullSet;
6451
6452 // TODO: Extend to other opcodes such as mul, and div
6453 switch (BO->getOpcode()) {
6454 default:
6455 return FullSet;
6456 case Instruction::AShr:
6457 case Instruction::LShr:
6458 case Instruction::Shl:
6459 break;
6460 };
6461
6462 if (BO->getOperand(0) != P)
6463 // TODO: Handle the power function forms some day.
6464 return FullSet;
6465
6466 unsigned TC = getSmallConstantMaxTripCount(L);
6467 if (!TC || TC >= BitWidth)
6468 return FullSet;
6469
6470 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6471 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6472 assert(KnownStart.getBitWidth() == BitWidth &&
6473 KnownStep.getBitWidth() == BitWidth);
6474
6475 // Compute total shift amount, being careful of overflow and bitwidths.
6476 auto MaxShiftAmt = KnownStep.getMaxValue();
6477 APInt TCAP(BitWidth, TC-1);
6478 bool Overflow = false;
6479 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6480 if (Overflow)
6481 return FullSet;
6482
6483 switch (BO->getOpcode()) {
6484 default:
6485 llvm_unreachable("filtered out above");
6486 case Instruction::AShr: {
6487 // For each ashr, three cases:
6488 // shift = 0 => unchanged value
6489 // saturation => 0 or -1
6490 // other => a value closer to zero (of the same sign)
6491 // Thus, the end value is closer to zero than the start.
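    // Editor's illustration (hypothetical numbers, not from the source): if the
    // start is known non-negative, every later value is the start shifted right
    // and hence no larger, so the recurrence stays within
    // [KnownEnd.getMinValue(), KnownStart.getMaxValue() + 1).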
6492 auto KnownEnd = KnownBits::ashr(KnownStart,
6493 KnownBits::makeConstant(TotalShift));
6494 if (KnownStart.isNonNegative())
6495 // Analogous to lshr (simply not yet canonicalized)
6496 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6497 KnownStart.getMaxValue() + 1);
6498 if (KnownStart.isNegative())
6499 // End >=u Start && End <=s Start
6500 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6501 KnownEnd.getMaxValue() + 1);
6502 break;
6503 }
6504 case Instruction::LShr: {
6505 // For each lshr, three cases:
6506 // shift = 0 => unchanged value
6507 // saturation => 0
6508 // other => a smaller positive number
6509 // Thus, the low end of the unsigned range is the last value produced.
6510 auto KnownEnd = KnownBits::lshr(KnownStart,
6511 KnownBits::makeConstant(TotalShift));
6512 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6513 KnownStart.getMaxValue() + 1);
6514 }
6515 case Instruction::Shl: {
6516 // Iff no bits are shifted out, value increases on every shift.
6517 auto KnownEnd = KnownBits::shl(KnownStart,
6518 KnownBits::makeConstant(TotalShift));
6519 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6520 return ConstantRange(KnownStart.getMinValue(),
6521 KnownEnd.getMaxValue() + 1);
6522 break;
6523 }
6524 };
6525 return FullSet;
6526}
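// Editor's illustration (made-up IR, names hypothetical): for the shift
// recurrence
//   %iv      = phi i32 [ %start, %ph ], [ %iv.next, %latch ]
//   %iv.next = lshr i32 %iv, 1
// with a small constant max trip count, the code above concludes the value can
// never exceed the maximum possible value of %start, tightening the unsigned
// range.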
6527
6528const ConstantRange &
6529ScalarEvolution::getRangeRefIter(const SCEV *S,
6530 ScalarEvolution::RangeSignHint SignHint) {
6531 DenseMap<const SCEV *, ConstantRange> &Cache =
6532 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6533 : SignedRanges;
6534 SmallVector<const SCEV *> WorkList;
6535 SmallPtrSet<const SCEV *, 8> Seen;
6536
6537 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6538 // SCEVUnknown PHI node.
6539 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6540 if (!Seen.insert(Expr).second)
6541 return;
6542 if (Cache.contains(Expr))
6543 return;
6544 switch (Expr->getSCEVType()) {
6545 case scUnknown:
6546 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6547 break;
6548 [[fallthrough]];
6549 case scConstant:
6550 case scVScale:
6551 case scTruncate:
6552 case scZeroExtend:
6553 case scSignExtend:
6554 case scPtrToInt:
6555 case scAddExpr:
6556 case scMulExpr:
6557 case scUDivExpr:
6558 case scAddRecExpr:
6559 case scUMaxExpr:
6560 case scSMaxExpr:
6561 case scUMinExpr:
6562 case scSMinExpr:
6563 case scSequentialUMinExpr:
6564 WorkList.push_back(Expr);
6565 break;
6566 case scCouldNotCompute:
6567 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6568 }
6569 };
6570 AddToWorklist(S);
6571
6572 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6573 for (unsigned I = 0; I != WorkList.size(); ++I) {
6574 const SCEV *P = WorkList[I];
6575 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6576 // If it is not a `SCEVUnknown`, just recurse into operands.
6577 if (!UnknownS) {
6578 for (const SCEV *Op : P->operands())
6579 AddToWorklist(Op);
6580 continue;
6581 }
6582 // `SCEVUnknown`'s require special treatment.
6583 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6584 if (!PendingPhiRangesIter.insert(P).second)
6585 continue;
6586 for (auto &Op : reverse(P->operands()))
6587 AddToWorklist(getSCEV(Op));
6588 }
6589 }
6590
6591 if (!WorkList.empty()) {
6592 // Use getRangeRef to compute ranges for items in the worklist in reverse
6593 // order. This will force ranges for earlier operands to be computed before
6594 // their users in most cases.
6595 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6596 getRangeRef(P, SignHint);
6597
6598 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6599 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6600 PendingPhiRangesIter.erase(P);
6601 }
6602 }
6603
6604 return getRangeRef(S, SignHint, 0);
6605}
6606
6607/// Determine the range for a particular SCEV. If SignHint is
6608/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6609/// with a "cleaner" unsigned (resp. signed) representation.
6610const ConstantRange &ScalarEvolution::getRangeRef(
6611 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6612 DenseMap<const SCEV *, ConstantRange> &Cache =
6613 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6614 : SignedRanges;
6615 ConstantRange::PreferredRangeType RangeType =
6616 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6617 : ConstantRange::Signed;
6618
6619 // See if we've computed this range already.
6620 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6621 if (I != Cache.end())
6622 return I->second;
6623
6624 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6625 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6626
6627 // Switch to iteratively computing the range for S, if it is part of a deeply
6628 // nested expression.
6629 if (Depth > RangeIterThreshold)
6630 return getRangeRefIter(S, SignHint);
6631
6632 unsigned BitWidth = getTypeSizeInBits(S->getType());
6633 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6634 using OBO = OverflowingBinaryOperator;
6635
6636 // If the value has known zeros, the maximum value will have those known zeros
6637 // as well.
6638 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6639 APInt Multiple = getNonZeroConstantMultiple(S);
6640 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6641 if (!Remainder.isZero())
6642 ConservativeResult =
6643 ConstantRange(APInt::getZero(BitWidth),
6644 APInt::getMaxValue(BitWidth) - Remainder + 1);
6645 }
6646 else {
6647 uint32_t TZ = getMinTrailingZeros(S);
6648 if (TZ != 0) {
6649 ConservativeResult = ConstantRange(
6650 APInt::getSignedMinValue(BitWidth).ashr(TZ).shl(TZ),
6651 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6652 }
6653 }
6654
6655 switch (S->getSCEVType()) {
6656 case scConstant:
6657 llvm_unreachable("Already handled above.");
6658 case scVScale:
6659 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6660 case scTruncate: {
6661 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6662 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6663 return setRange(
6664 Trunc, SignHint,
6665 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6666 }
6667 case scZeroExtend: {
6668 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6669 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6670 return setRange(
6671 ZExt, SignHint,
6672 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6673 }
6674 case scSignExtend: {
6675 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6676 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6677 return setRange(
6678 SExt, SignHint,
6679 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6680 }
6681 case scPtrToInt: {
6682 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6683 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6684 return setRange(PtrToInt, SignHint, X);
6685 }
6686 case scAddExpr: {
6687 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6688 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6689 unsigned WrapType = OBO::AnyWrap;
6690 if (Add->hasNoSignedWrap())
6691 WrapType |= OBO::NoSignedWrap;
6692 if (Add->hasNoUnsignedWrap())
6693 WrapType |= OBO::NoUnsignedWrap;
6694 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
6695 X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint, Depth + 1),
6696 WrapType, RangeType);
6697 return setRange(Add, SignHint,
6698 ConservativeResult.intersectWith(X, RangeType));
6699 }
6700 case scMulExpr: {
6701 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6702 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6703 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
6704 X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint, Depth + 1));
6705 return setRange(Mul, SignHint,
6706 ConservativeResult.intersectWith(X, RangeType));
6707 }
6708 case scUDivExpr: {
6709 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6710 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6711 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6712 return setRange(UDiv, SignHint,
6713 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6714 }
6715 case scAddRecExpr: {
6716 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6717 // If there's no unsigned wrap, the value will never be less than its
6718 // initial value.
6719 if (AddRec->hasNoUnsignedWrap()) {
6720 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6721 if (!UnsignedMinValue.isZero())
6722 ConservativeResult = ConservativeResult.intersectWith(
6723 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6724 }
6725
6726 // If there's no signed wrap, and all the operands except the initial value
6727 // have the same sign or zero, the value won't ever be:
6728 // 1: smaller than the initial value if the operands are non-negative,
6729 // 2: bigger than the initial value if the operands are non-positive.
6730 // In both cases, the value cannot cross the signed min/max boundary.
6731 if (AddRec->hasNoSignedWrap()) {
6732 bool AllNonNeg = true;
6733 bool AllNonPos = true;
6734 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6735 if (!isKnownNonNegative(AddRec->getOperand(i)))
6736 AllNonNeg = false;
6737 if (!isKnownNonPositive(AddRec->getOperand(i)))
6738 AllNonPos = false;
6739 }
6740 if (AllNonNeg)
6741 ConservativeResult = ConservativeResult.intersectWith(
6742 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6743 APInt::getSignedMinValue(BitWidth)),
6744 RangeType);
6745 else if (AllNonPos)
6746 ConservativeResult = ConservativeResult.intersectWith(
6747 ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6748 getSignedRangeMax(AddRec->getStart()) +
6749 1),
6750 RangeType);
6751 }
6752
6753 // TODO: non-affine addrec
6754 if (AddRec->isAffine()) {
6755 const SCEV *MaxBEScev =
6756 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6757 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6758 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6759
6760 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6761 // MaxBECount's active bits are all <= AddRec's bit width.
6762 if (MaxBECount.getBitWidth() > BitWidth &&
6763 MaxBECount.getActiveBits() <= BitWidth)
6764 MaxBECount = MaxBECount.trunc(BitWidth);
6765 else if (MaxBECount.getBitWidth() < BitWidth)
6766 MaxBECount = MaxBECount.zext(BitWidth);
6767
6768 if (MaxBECount.getBitWidth() == BitWidth) {
6769 auto RangeFromAffine = getRangeForAffineAR(
6770 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6771 ConservativeResult =
6772 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6773
6774 auto RangeFromFactoring = getRangeViaFactoring(
6775 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6776 ConservativeResult =
6777 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6778 }
6779 }
6780
6781 // Now try symbolic BE count and more powerful methods.
6782 if (UseExpensiveRangeSharpening) {
6783 const SCEV *SymbolicMaxBECount =
6784 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6785 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6786 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6787 AddRec->hasNoSelfWrap()) {
6788 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6789 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6790 ConservativeResult =
6791 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6792 }
6793 }
6794 }
6795
6796 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6797 }
6798 case scUMaxExpr:
6799 case scSMaxExpr:
6800 case scUMinExpr:
6801 case scSMinExpr:
6802 case scSequentialUMinExpr: {
6803 Intrinsic::ID ID;
6804 switch (S->getSCEVType()) {
6805 case scUMaxExpr:
6806 ID = Intrinsic::umax;
6807 break;
6808 case scSMaxExpr:
6809 ID = Intrinsic::smax;
6810 break;
6811 case scUMinExpr:
6812 case scSequentialUMinExpr:
6813 ID = Intrinsic::umin;
6814 break;
6815 case scSMinExpr:
6816 ID = Intrinsic::smin;
6817 break;
6818 default:
6819 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6820 }
6821
6822 const auto *NAry = cast<SCEVNAryExpr>(S);
6823 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6824 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6825 X = X.intrinsic(
6826 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6827 return setRange(S, SignHint,
6828 ConservativeResult.intersectWith(X, RangeType));
6829 }
6830 case scUnknown: {
6831 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6832 Value *V = U->getValue();
6833
6834 // Check if the IR explicitly contains !range metadata.
6835 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6836 if (MDRange)
6837 ConservativeResult =
6838 ConservativeResult.intersectWith(*MDRange, RangeType);
6839
6840 // Use facts about recurrences in the underlying IR. Note that add
6841 // recurrences are AddRecExprs and thus don't hit this path. This
6842 // primarily handles shift recurrences.
6843 auto CR = getRangeForUnknownRecurrence(U);
6844 ConservativeResult = ConservativeResult.intersectWith(CR);
6845
6846 // See if ValueTracking can give us a useful range.
6847 const DataLayout &DL = getDataLayout();
6848 KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
6849 if (Known.getBitWidth() != BitWidth)
6850 Known = Known.zextOrTrunc(BitWidth);
6851
6852 // ValueTracking may be able to compute a tighter result for the number of
6853 // sign bits than for the value of those sign bits.
6854 unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
6855 if (U->getType()->isPointerTy()) {
6856 // If the pointer size is larger than the index size type, this can cause
6857 // NS to be larger than BitWidth. So compensate for this.
6858 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6859 int ptrIdxDiff = ptrSize - BitWidth;
6860 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6861 NS -= ptrIdxDiff;
6862 }
6863
6864 if (NS > 1) {
6865 // If we know any of the sign bits, we know all of the sign bits.
6866 if (!Known.Zero.getHiBits(NS).isZero())
6867 Known.Zero.setHighBits(NS);
6868 if (!Known.One.getHiBits(NS).isZero())
6869 Known.One.setHighBits(NS);
6870 }
6871
6872 if (Known.getMinValue() != Known.getMaxValue() + 1)
6873 ConservativeResult = ConservativeResult.intersectWith(
6874 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6875 RangeType);
6876 if (NS > 1)
6877 ConservativeResult = ConservativeResult.intersectWith(
6878 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6879 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6880 RangeType);
6881
6882 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6883 // Strengthen the range if the underlying IR value is a
6884 // global/alloca/heap allocation using the size of the object.
6885 ObjectSizeOpts Opts;
6886 Opts.RoundToAlign = false;
6887 Opts.NullIsUnknownSize = true;
6888 uint64_t ObjSize;
6889 if ((isa<GlobalVariable>(V) || isa<AllocaInst>(V) ||
6890 isAllocationFn(V, &TLI)) &&
6891 getObjectSize(V, ObjSize, DL, &TLI, Opts) && ObjSize > 1) {
6892 // The highest address the object can start is ObjSize bytes before the
6893 // end (unsigned max value). If this value is not a multiple of the
6894 // alignment, the last possible start value is the next lowest multiple
6895 // of the alignment. Note: The computations below cannot overflow,
6896 // because if they did, there would be no possible start address for
6897 // the object.
6898 APInt MaxVal = APInt::getMaxValue(BitWidth) - APInt(BitWidth, ObjSize);
6899 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6900 uint64_t Rem = MaxVal.urem(Align);
6901 MaxVal -= APInt(BitWidth, Rem);
6902 APInt MinVal = APInt::getZero(BitWidth);
6903 if (llvm::isKnownNonZero(V, DL))
6904 MinVal = Align;
6905 ConservativeResult = ConservativeResult.intersectWith(
6906 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6907 }
6908 }
6909
6910 // A range of Phi is a subset of union of all ranges of its input.
6911 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6912 // Make sure that we do not run over cycled Phis.
6913 if (PendingPhiRanges.insert(Phi).second) {
6914 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6915
6916 for (const auto &Op : Phi->operands()) {
6917 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6918 RangeFromOps = RangeFromOps.unionWith(OpRange);
6919 // No point to continue if we already have a full set.
6920 if (RangeFromOps.isFullSet())
6921 break;
6922 }
6923 ConservativeResult =
6924 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6925 bool Erased = PendingPhiRanges.erase(Phi);
6926 assert(Erased && "Failed to erase Phi properly?");
6927 (void)Erased;
6928 }
6929 }
6930
6931 // vscale can't be equal to zero
6932 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6933 if (II->getIntrinsicID() == Intrinsic::vscale) {
6934 ConstantRange Disallowed = APInt::getZero(BitWidth);
6935 ConservativeResult = ConservativeResult.difference(Disallowed);
6936 }
6937
6938 return setRange(U, SignHint, std::move(ConservativeResult));
6939 }
6940 case scCouldNotCompute:
6941 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6942 }
6943
6944 return setRange(S, SignHint, std::move(ConservativeResult));
6945}
6946
6947// Given a StartRange, Step and MaxBECount for an expression compute a range of
6948// values that the expression can take. Initially, the expression has a value
6949// from StartRange and then is changed by Step up to MaxBECount times. Signed
6950// argument defines if we treat Step as signed or unsigned.
6951 static ConstantRange getRangeForAffineARHelper(APInt Step,
6952 const ConstantRange &StartRange,
6953 const APInt &MaxBECount,
6954 bool Signed) {
6955 unsigned BitWidth = Step.getBitWidth();
6956 assert(BitWidth == StartRange.getBitWidth() &&
6957 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6958 // If either Step or MaxBECount is 0, then the expression won't change, and we
6959 // just need to return the initial range.
6960 if (Step == 0 || MaxBECount == 0)
6961 return StartRange;
6962
6963 // If we don't know anything about the initial value (i.e. StartRange is
6964 // FullRange), then we don't know anything about the final range either.
6965 // Return FullRange.
6966 if (StartRange.isFullSet())
6967 return ConstantRange::getFull(BitWidth);
6968
6969 // If Step is signed and negative, then we use its absolute value, but we also
6970 // note that we're moving in the opposite direction.
6971 bool Descending = Signed && Step.isNegative();
6972
6973 if (Signed)
6974 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
6975 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
6976 // These equations hold true due to the well-defined wrap-around behavior of
6977 // APInt.
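  // Editor's illustration (not part of the source): in i8, Step = -3 gives
  // Step.abs() = 3 with Descending = true; Step = -128 stays 0x80, whose
  // unsigned value 128 is exactly the intended magnitude.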
6978 Step = Step.abs();
6979
6980 // Check if Offset is more than full span of BitWidth. If it is, the
6981 // expression is guaranteed to overflow.
6982 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
6983 return ConstantRange::getFull(BitWidth);
6984
6985 // Offset is by how much the expression can change. Checks above guarantee no
6986 // overflow here.
6987 APInt Offset = Step * MaxBECount;
6988
6989 // Minimum value of the final range will match the minimal value of StartRange
6990 // if the expression is increasing and will be decreased by Offset otherwise.
6991 // Maximum value of the final range will match the maximal value of StartRange
6992 // if the expression is decreasing and will be increased by Offset otherwise.
6993 APInt StartLower = StartRange.getLower();
6994 APInt StartUpper = StartRange.getUpper() - 1;
6995 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
6996 : (StartUpper + std::move(Offset));
6997
6998 // It's possible that the new minimum/maximum value will fall into the initial
6999 // range (due to wrap around). This means that the expression can take any
7000 // value in this bitwidth, and we have to return full range.
7001 if (StartRange.contains(MovedBoundary))
7002 return ConstantRange::getFull(BitWidth);
7003
7004 APInt NewLower =
7005 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7006 APInt NewUpper =
7007 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7008 NewUpper += 1;
7009
7010 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7011 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7012}
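// Editor's worked example (hypothetical numbers, not from the source): with
// StartRange = [0, 10), Step = 3, MaxBECount = 5 and Signed = false, the
// offset is 15 and MovedBoundary = 9 + 15 = 24, which lies outside [0, 10),
// so the returned range is [0, 25).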
7013
7014ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7015 const SCEV *Step,
7016 const APInt &MaxBECount) {
7017 assert(getTypeSizeInBits(Start->getType()) ==
7018 getTypeSizeInBits(Step->getType()) &&
7019 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7020 "mismatched bit widths");
7021
7022 // First, consider step signed.
7023 ConstantRange StartSRange = getSignedRange(Start);
7024 ConstantRange StepSRange = getSignedRange(Step);
7025
7026 // If Step can be both positive and negative, we need to find ranges for the
7027 // maximum absolute step values in both directions and union them.
7028 ConstantRange SR = getRangeForAffineARHelper(
7029 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7030 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7031 StartSRange, MaxBECount,
7032 /* Signed = */ true));
7033
7034 // Next, consider step unsigned.
7035 ConstantRange UR = getRangeForAffineARHelper(
7036 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7037 /* Signed = */ false);
7038
7039 // Finally, intersect signed and unsigned ranges.
7040 return SR.intersectWith(UR, ConstantRange::Smallest);
7041}
7042
7043ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7044 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7045 ScalarEvolution::RangeSignHint SignHint) {
7046 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7047 assert(AddRec->hasNoSelfWrap() &&
7048 "This only works for non-self-wrapping AddRecs!");
7049 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7050 const SCEV *Step = AddRec->getStepRecurrence(*this);
7051 // Only deal with constant step to save compile time.
7052 if (!isa<SCEVConstant>(Step))
7053 return ConstantRange::getFull(BitWidth);
7054 // Let's make sure that we can prove that we do not self-wrap during
7055 // MaxBECount iterations. We need this because MaxBECount is a maximum
7056 // iteration count estimate, and we might infer nw from some exit for which we
7057 // do not know max exit count (or any other side reasoning).
7058 // TODO: Turn into assert at some point.
7059 if (getTypeSizeInBits(MaxBECount->getType()) >
7060 getTypeSizeInBits(AddRec->getType()))
7061 return ConstantRange::getFull(BitWidth);
7062 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7063 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7064 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7065 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7066 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7067 MaxItersWithoutWrap))
7068 return ConstantRange::getFull(BitWidth);
7069
7070 ICmpInst::Predicate LEPred =
7071 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7072 ICmpInst::Predicate GEPred =
7073 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7074 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7075
7076 // We know that there is no self-wrap. Let's take Start and End values and
7077 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7078 // the iteration. They either lie inside the range [Min(Start, End),
7079 // Max(Start, End)] or outside it:
7080 //
7081 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7082 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7083 //
7084 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7085 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7086 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7087 // Start <= End and step is positive, or Start >= End and step is negative.
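  // Editor's illustration (hypothetical, not from the source): for {0,+,1}
  // with a provably positive step and End evaluating to 10, Start <= End holds,
  // so Case 1 applies and the returned range is the union of the Start and End
  // ranges, roughly [0, 11).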
7088 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7089 ConstantRange StartRange = getRangeRef(Start, SignHint);
7090 ConstantRange EndRange = getRangeRef(End, SignHint);
7091 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7092 // If they already cover full iteration space, we will know nothing useful
7093 // even if we prove what we want to prove.
7094 if (RangeBetween.isFullSet())
7095 return RangeBetween;
7096 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7097 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7098 : RangeBetween.isWrappedSet();
7099 if (IsWrappedSet)
7100 return ConstantRange::getFull(BitWidth);
7101
7102 if (isKnownPositive(Step) &&
7103 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7104 return RangeBetween;
7105 if (isKnownNegative(Step) &&
7106 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7107 return RangeBetween;
7108 return ConstantRange::getFull(BitWidth);
7109}
7110
7111ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7112 const SCEV *Step,
7113 const APInt &MaxBECount) {
7114 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7115 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7116
7117 unsigned BitWidth = MaxBECount.getBitWidth();
7118 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7119 getTypeSizeInBits(Step->getType()) == BitWidth &&
7120 "mismatched bit widths");
7121
7122 struct SelectPattern {
7123 Value *Condition = nullptr;
7124 APInt TrueValue;
7125 APInt FalseValue;
7126
7127 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7128 const SCEV *S) {
7129 std::optional<unsigned> CastOp;
7130 APInt Offset(BitWidth, 0);
7131
7132 assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
7133 "Should be!");
7134
7135 // Peel off a constant offset:
7136 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7137 // In the future we could consider being smarter here and handle
7138 // {Start+Step,+,Step} too.
7139 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7140 return;
7141
7142 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7143 S = SA->getOperand(1);
7144 }
7145
7146 // Peel off a cast operation
7147 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7148 CastOp = SCast->getSCEVType();
7149 S = SCast->getOperand();
7150 }
7151
7152 using namespace llvm::PatternMatch;
7153
7154 auto *SU = dyn_cast<SCEVUnknown>(S);
7155 const APInt *TrueVal, *FalseVal;
7156 if (!SU ||
7157 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7158 m_APInt(FalseVal)))) {
7159 Condition = nullptr;
7160 return;
7161 }
7162
7163 TrueValue = *TrueVal;
7164 FalseValue = *FalseVal;
7165
7166 // Re-apply the cast we peeled off earlier
7167 if (CastOp)
7168 switch (*CastOp) {
7169 default:
7170 llvm_unreachable("Unknown SCEV cast type!");
7171
7172 case scTruncate:
7173 TrueValue = TrueValue.trunc(BitWidth);
7174 FalseValue = FalseValue.trunc(BitWidth);
7175 break;
7176 case scZeroExtend:
7177 TrueValue = TrueValue.zext(BitWidth);
7178 FalseValue = FalseValue.zext(BitWidth);
7179 break;
7180 case scSignExtend:
7181 TrueValue = TrueValue.sext(BitWidth);
7182 FalseValue = FalseValue.sext(BitWidth);
7183 break;
7184 }
7185
7186 // Re-apply the constant offset we peeled off earlier
7187 TrueValue += Offset;
7188 FalseValue += Offset;
7189 }
7190
7191 bool isRecognized() { return Condition != nullptr; }
7192 };
7193
7194 SelectPattern StartPattern(*this, BitWidth, Start);
7195 if (!StartPattern.isRecognized())
7196 return ConstantRange::getFull(BitWidth);
7197
7198 SelectPattern StepPattern(*this, BitWidth, Step);
7199 if (!StepPattern.isRecognized())
7200 return ConstantRange::getFull(BitWidth);
7201
7202 if (StartPattern.Condition != StepPattern.Condition) {
7203 // We don't handle this case today; but we could, by considering four
7204 // possibilities below instead of two. I'm not sure if there are cases where
7205 // that will help over what getRange already does, though.
7206 return ConstantRange::getFull(BitWidth);
7207 }
7208
7209 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7210 // construct arbitrary general SCEV expressions here. This function is called
7211 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7212 // say) can end up caching a suboptimal value.
7213
7214 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7215 // C2352 and C2512 (otherwise it isn't needed).
7216
7217 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7218 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7219 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7220 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7221
7222 ConstantRange TrueRange =
7223 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7224 ConstantRange FalseRange =
7225 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7226
7227 return TrueRange.unionWith(FalseRange);
7228}
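// Editor's illustration (names hypothetical, not from the source): with
// Start = (%c ? 0 : 8) and Step = (%c ? 1 : 2), the ranges of {0,+,1} and
// {8,+,2} are computed independently and unioned, per the RangeOf identity in
// the comment at the top of this function.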
7229
7230SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7231 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7232 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7233
7234 // Return early if there are no flags to propagate to the SCEV.
7235 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7236 if (BinOp->hasNoUnsignedWrap())
7237 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7238 if (BinOp->hasNoSignedWrap())
7239 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7240 if (Flags == SCEV::FlagAnyWrap)
7241 return SCEV::FlagAnyWrap;
7242
7243 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7244}
7245
7246const Instruction *
7247ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7248 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7249 return &*AddRec->getLoop()->getHeader()->begin();
7250 if (auto *U = dyn_cast<SCEVUnknown>(S))
7251 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7252 return I;
7253 return nullptr;
7254}
7255
7256const Instruction *
7257ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7258 bool &Precise) {
7259 Precise = true;
7260 // Do a bounded search of the def relation of the requested SCEVs.
7261 SmallSet<const SCEV *, 16> Visited;
7262 SmallVector<const SCEV *> Worklist;
7263 auto pushOp = [&](const SCEV *S) {
7264 if (!Visited.insert(S).second)
7265 return;
7266 // Threshold of 30 here is arbitrary.
7267 if (Visited.size() > 30) {
7268 Precise = false;
7269 return;
7270 }
7271 Worklist.push_back(S);
7272 };
7273
7274 for (const auto *S : Ops)
7275 pushOp(S);
7276
7277 const Instruction *Bound = nullptr;
7278 while (!Worklist.empty()) {
7279 auto *S = Worklist.pop_back_val();
7280 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7281 if (!Bound || DT.dominates(Bound, DefI))
7282 Bound = DefI;
7283 } else {
7284 for (const auto *Op : S->operands())
7285 pushOp(Op);
7286 }
7287 }
7288 return Bound ? Bound : &*F.getEntryBlock().begin();
7289}
7290
7291const Instruction *
7292ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7293 bool Discard;
7294 return getDefiningScopeBound(Ops, Discard);
7295}
7296
7297bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7298 const Instruction *B) {
7299 if (A->getParent() == B->getParent() &&
7300 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7301 B->getIterator()))
7302 return true;
7303
7304 auto *BLoop = LI.getLoopFor(B->getParent());
7305 if (BLoop && BLoop->getHeader() == B->getParent() &&
7306 BLoop->getLoopPreheader() == A->getParent() &&
7307 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7308 A->getParent()->end()) &&
7309 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7310 B->getIterator()))
7311 return true;
7312 return false;
7313}
7314
7315
7316bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7317 // Only proceed if we can prove that I does not yield poison.
7318 if (!programUndefinedIfPoison(I))
7319 return false;
7320
7321 // At this point we know that if I is executed, then it does not wrap
7322 // according to at least one of NSW or NUW. If I is not executed, then we do
7323 // not know if the calculation that I represents would wrap. Multiple
7324 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7325 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7326 // derived from other instructions that map to the same SCEV. We cannot make
7327 // that guarantee for cases where I is not executed. So we need to find an
7328 // upper bound on the defining scope for the SCEV, and prove that I is
7329 // executed every time we enter that scope. When the bounding scope is a
7330 // loop (the common case), this is equivalent to proving I executes on every
7331 // iteration of that loop.
7332 SmallVector<const SCEV *> SCEVOps;
7333 for (const Use &Op : I->operands()) {
7334 // I could be an extractvalue from a call to an overflow intrinsic.
7335 // TODO: We can do better here in some cases.
7336 if (isSCEVable(Op->getType()))
7337 SCEVOps.push_back(getSCEV(Op));
7338 }
7339 auto *DefI = getDefiningScopeBound(SCEVOps);
7340 return isGuaranteedToTransferExecutionTo(DefI, I);
7341}
7342
7343bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7344 // If we know that \c I can never be poison period, then that's enough.
7345 if (isSCEVExprNeverPoison(I))
7346 return true;
7347
7348 // If the loop only has one exit, then we know that, if the loop is entered,
7349 // any instruction dominating that exit will be executed. If any such
7350 // instruction would result in UB, the addrec cannot be poison.
7351 //
7352 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7353 // also handles uses outside the loop header (they just need to dominate the
7354 // single exit).
7355
7356 auto *ExitingBB = L->getExitingBlock();
7357 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7358 return false;
7359
7360 SmallPtrSet<const Value *, 16> KnownPoison;
7361 SmallVector<const Instruction *, 8> Worklist;
7362
7363 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7364 // things that are known to be poison under that assumption go on the
7365 // Worklist.
7366 KnownPoison.insert(I);
7367 Worklist.push_back(I);
7368
7369 while (!Worklist.empty()) {
7370 const Instruction *Poison = Worklist.pop_back_val();
7371
7372 for (const Use &U : Poison->uses()) {
7373 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7374 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7375 DT.dominates(PoisonUser->getParent(), ExitingBB))
7376 return true;
7377
7378 if (propagatesPoison(U) && L->contains(PoisonUser))
7379 if (KnownPoison.insert(PoisonUser).second)
7380 Worklist.push_back(PoisonUser);
7381 }
7382 }
7383
7384 return false;
7385}
7386
7387ScalarEvolution::LoopProperties
7388ScalarEvolution::getLoopProperties(const Loop *L) {
7389 using LoopProperties = ScalarEvolution::LoopProperties;
7390
7391 auto Itr = LoopPropertiesCache.find(L);
7392 if (Itr == LoopPropertiesCache.end()) {
7393 auto HasSideEffects = [](Instruction *I) {
7394 if (auto *SI = dyn_cast<StoreInst>(I))
7395 return !SI->isSimple();
7396
7397 return I->mayThrow() || I->mayWriteToMemory();
7398 };
7399
7400 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7401 /*HasNoSideEffects*/ true};
7402
7403 for (auto *BB : L->getBlocks())
7404 for (auto &I : *BB) {
7405 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7406 LP.HasNoAbnormalExits = false;
7407 if (HasSideEffects(&I))
7408 LP.HasNoSideEffects = false;
7409 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7410 break; // We're already as pessimistic as we can get.
7411 }
7412
7413 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7414 assert(InsertPair.second && "We just checked!");
7415 Itr = InsertPair.first;
7416 }
7417
7418 return Itr->second;
7419}
7420
7421 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7422 // A mustprogress loop without side effects must be finite.
7423 // TODO: The check used here is very conservative. It's only *specific*
7424 // side effects which are well defined in infinite loops.
7425 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7426}
7427
7428const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7429 // Worklist item with a Value and a bool indicating whether all operands have
7430 // been visited already.
7431 using PointerTy = PointerIntPair<Value *, 1, bool>;
7432 SmallVector<PointerTy> Stack;
7433
7434 Stack.emplace_back(V, true);
7435 Stack.emplace_back(V, false);
7436 while (!Stack.empty()) {
7437 auto E = Stack.pop_back_val();
7438 Value *CurV = E.getPointer();
7439
7440 if (getExistingSCEV(CurV))
7441 continue;
7442
7443 SmallVector<Value *> Ops;
7444 const SCEV *CreatedSCEV = nullptr;
7445 // If all operands have been visited already, create the SCEV.
7446 if (E.getInt()) {
7447 CreatedSCEV = createSCEV(CurV);
7448 } else {
7449 // Otherwise get the operands we need to create SCEVs for before creating
7450 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7451 // just use it.
7452 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7453 }
7454
7455 if (CreatedSCEV) {
7456 insertValueToMap(CurV, CreatedSCEV);
7457 } else {
7458 // Queue CurV for SCEV creation, followed by its operands, which need to
7459 // be constructed first.
7460 Stack.emplace_back(CurV, true);
7461 for (Value *Op : Ops)
7462 Stack.emplace_back(Op, false);
7463 }
7464 }
7465
7466 return getExistingSCEV(V);
7467}
7468
7469const SCEV *
7470ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7471 if (!isSCEVable(V->getType()))
7472 return getUnknown(V);
7473
7474 if (Instruction *I = dyn_cast<Instruction>(V)) {
7475 // Don't attempt to analyze instructions in blocks that aren't
7476 // reachable. Such instructions don't matter, and they aren't required
7477 // to obey basic rules for definitions dominating uses which this
7478 // analysis depends on.
7479 if (!DT.isReachableFromEntry(I->getParent()))
7480 return getUnknown(PoisonValue::get(V->getType()));
7481 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7482 return getConstant(CI);
7483 else if (isa<GlobalAlias>(V))
7484 return getUnknown(V);
7485 else if (!isa<ConstantExpr>(V))
7486 return getUnknown(V);
7487
7488 Operator *U = cast<Operator>(V);
7489 if (auto BO =
7490 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7491 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7492 switch (BO->Opcode) {
7493 case Instruction::Add:
7494 case Instruction::Mul: {
7495 // For additions and multiplications, traverse add/mul chains for which we
7496 // can potentially create a single SCEV, to reduce the number of
7497 // get{Add,Mul}Expr calls.
7498 do {
7499 if (BO->Op) {
7500 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7501 Ops.push_back(BO->Op);
7502 break;
7503 }
7504 }
7505 Ops.push_back(BO->RHS);
7506 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7507 dyn_cast<Instruction>(V));
7508 if (!NewBO ||
7509 (BO->Opcode == Instruction::Add &&
7510 (NewBO->Opcode != Instruction::Add &&
7511 NewBO->Opcode != Instruction::Sub)) ||
7512 (BO->Opcode == Instruction::Mul &&
7513 NewBO->Opcode != Instruction::Mul)) {
7514 Ops.push_back(BO->LHS);
7515 break;
7516 }
7517 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7518 // requires a SCEV for the LHS.
7519 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7520 auto *I = dyn_cast<Instruction>(BO->Op);
7521 if (I && programUndefinedIfPoison(I)) {
7522 Ops.push_back(BO->LHS);
7523 break;
7524 }
7525 }
7526 BO = NewBO;
7527 } while (true);
7528 return nullptr;
7529 }
7530 case Instruction::Sub:
7531 case Instruction::UDiv:
7532 case Instruction::URem:
7533 break;
7534 case Instruction::AShr:
7535 case Instruction::Shl:
7536 case Instruction::Xor:
7537 if (!IsConstArg)
7538 return nullptr;
7539 break;
7540 case Instruction::And:
7541 case Instruction::Or:
7542 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7543 return nullptr;
7544 break;
7545 case Instruction::LShr:
7546 return getUnknown(V);
7547 default:
7548 llvm_unreachable("Unhandled binop");
7549 break;
7550 }
7551
7552 Ops.push_back(BO->LHS);
7553 Ops.push_back(BO->RHS);
7554 return nullptr;
7555 }
7556
7557 switch (U->getOpcode()) {
7558 case Instruction::Trunc:
7559 case Instruction::ZExt:
7560 case Instruction::SExt:
7561 case Instruction::PtrToInt:
7562 Ops.push_back(U->getOperand(0));
7563 return nullptr;
7564
7565 case Instruction::BitCast:
7566 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7567 Ops.push_back(U->getOperand(0));
7568 return nullptr;
7569 }
7570 return getUnknown(V);
7571
7572 case Instruction::SDiv:
7573 case Instruction::SRem:
7574 Ops.push_back(U->getOperand(0));
7575 Ops.push_back(U->getOperand(1));
7576 return nullptr;
7577
7578 case Instruction::GetElementPtr:
7579 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7580 "GEP source element type must be sized");
7581 for (Value *Index : U->operands())
7582 Ops.push_back(Index);
7583 return nullptr;
7584
7585 case Instruction::IntToPtr:
7586 return getUnknown(V);
7587
7588 case Instruction::PHI:
7589 // Keep constructing SCEVs for phis recursively for now.
7590 return nullptr;
7591
7592 case Instruction::Select: {
7593 // Check if U is a select that can be simplified to a SCEVUnknown.
7594 auto CanSimplifyToUnknown = [this, U]() {
7595 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7596 return false;
7597
7598 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7599 if (!ICI)
7600 return false;
7601 Value *LHS = ICI->getOperand(0);
7602 Value *RHS = ICI->getOperand(1);
7603 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7604 ICI->getPredicate() == CmpInst::ICMP_NE) {
7605 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7606 return true;
7607 } else if (getTypeSizeInBits(LHS->getType()) >
7608 getTypeSizeInBits(U->getType()))
7609 return true;
7610 return false;
7611 };
7612 if (CanSimplifyToUnknown())
7613 return getUnknown(U);
7614
7615 for (Value *Inc : U->operands())
7616 Ops.push_back(Inc);
7617 return nullptr;
7618 break;
7619 }
7620 case Instruction::Call:
7621 case Instruction::Invoke:
7622 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7623 Ops.push_back(RV);
7624 return nullptr;
7625 }
7626
7627 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7628 switch (II->getIntrinsicID()) {
7629 case Intrinsic::abs:
7630 Ops.push_back(II->getArgOperand(0));
7631 return nullptr;
7632 case Intrinsic::umax:
7633 case Intrinsic::umin:
7634 case Intrinsic::smax:
7635 case Intrinsic::smin:
7636 case Intrinsic::usub_sat:
7637 case Intrinsic::uadd_sat:
7638 Ops.push_back(II->getArgOperand(0));
7639 Ops.push_back(II->getArgOperand(1));
7640 return nullptr;
7641 case Intrinsic::start_loop_iterations:
7642 case Intrinsic::annotation:
7643 case Intrinsic::ptr_annotation:
7644 Ops.push_back(II->getArgOperand(0));
7645 return nullptr;
7646 default:
7647 break;
7648 }
7649 }
7650 break;
7651 }
7652
7653 return nullptr;
7654}
7655
7656const SCEV *ScalarEvolution::createSCEV(Value *V) {
7657 if (!isSCEVable(V->getType()))
7658 return getUnknown(V);
7659
7660 if (Instruction *I = dyn_cast<Instruction>(V)) {
7661 // Don't attempt to analyze instructions in blocks that aren't
7662 // reachable. Such instructions don't matter, and they aren't required
7663 // to obey basic rules for definitions dominating uses which this
7664 // analysis depends on.
7665 if (!DT.isReachableFromEntry(I->getParent()))
7666 return getUnknown(PoisonValue::get(V->getType()));
7667 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7668 return getConstant(CI);
7669 else if (isa<GlobalAlias>(V))
7670 return getUnknown(V);
7671 else if (!isa<ConstantExpr>(V))
7672 return getUnknown(V);
7673
7674 const SCEV *LHS;
7675 const SCEV *RHS;
7676
7677 Operator *U = cast<Operator>(V);
7678 if (auto BO =
7679 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7680 switch (BO->Opcode) {
7681 case Instruction::Add: {
7682 // The simple thing to do would be to just call getSCEV on both operands
7683 // and call getAddExpr with the result. However if we're looking at a
7684 // bunch of things all added together, this can be quite inefficient,
7685 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7686 // Instead, gather up all the operands and make a single getAddExpr call.
7687 // LLVM IR canonical form means we need only traverse the left operands.
7688 SmallVector<const SCEV *, 4> AddOps;
7689 do {
7690 if (BO->Op) {
7691 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7692 AddOps.push_back(OpSCEV);
7693 break;
7694 }
7695
7696 // If a NUW or NSW flag can be applied to the SCEV for this
7697 // addition, then compute the SCEV for this addition by itself
7698 // with a separate call to getAddExpr. We need to do that
7699 // instead of pushing the operands of the addition onto AddOps,
7700 // since the flags are only known to apply to this particular
7701 // addition - they may not apply to other additions that can be
7702 // formed with operands from AddOps.
7703 const SCEV *RHS = getSCEV(BO->RHS);
7704 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7705 if (Flags != SCEV::FlagAnyWrap) {
7706 const SCEV *LHS = getSCEV(BO->LHS);
7707 if (BO->Opcode == Instruction::Sub)
7708 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7709 else
7710 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7711 break;
7712 }
7713 }
7714
7715 if (BO->Opcode == Instruction::Sub)
7716 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7717 else
7718 AddOps.push_back(getSCEV(BO->RHS));
7719
7720 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7721 dyn_cast<Instruction>(V));
7722 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7723 NewBO->Opcode != Instruction::Sub)) {
7724 AddOps.push_back(getSCEV(BO->LHS));
7725 break;
7726 }
7727 BO = NewBO;
7728 } while (true);
7729
7730 return getAddExpr(AddOps);
7731 }
7732
7733 case Instruction::Mul: {
7734 SmallVector<const SCEV *, 4> MulOps;
7735 do {
7736 if (BO->Op) {
7737 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7738 MulOps.push_back(OpSCEV);
7739 break;
7740 }
7741
7742 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7743 if (Flags != SCEV::FlagAnyWrap) {
7744 LHS = getSCEV(BO->LHS);
7745 RHS = getSCEV(BO->RHS);
7746 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7747 break;
7748 }
7749 }
7750
7751 MulOps.push_back(getSCEV(BO->RHS));
7752 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7753 dyn_cast<Instruction>(V));
7754 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7755 MulOps.push_back(getSCEV(BO->LHS));
7756 break;
7757 }
7758 BO = NewBO;
7759 } while (true);
7760
7761 return getMulExpr(MulOps);
7762 }
7763 case Instruction::UDiv:
7764 LHS = getSCEV(BO->LHS);
7765 RHS = getSCEV(BO->RHS);
7766 return getUDivExpr(LHS, RHS);
7767 case Instruction::URem:
7768 LHS = getSCEV(BO->LHS);
7769 RHS = getSCEV(BO->RHS);
7770 return getURemExpr(LHS, RHS);
7771 case Instruction::Sub: {
7772 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7773 if (BO->Op)
7774 Flags = getNoWrapFlagsFromUB(BO->Op);
7775 LHS = getSCEV(BO->LHS);
7776 RHS = getSCEV(BO->RHS);
7777 return getMinusSCEV(LHS, RHS, Flags);
7778 }
7779 case Instruction::And:
7780 // For an expression like x&255 that merely masks off the high bits,
7781 // use zext(trunc(x)) as the SCEV expression.
7782 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7783 if (CI->isZero())
7784 return getSCEV(BO->RHS);
7785 if (CI->isMinusOne())
7786 return getSCEV(BO->LHS);
7787 const APInt &A = CI->getValue();
7788
7789 // Instcombine's ShrinkDemandedConstant may strip bits out of
7790 // constants, obscuring what would otherwise be a low-bits mask.
7791 // Use computeKnownBits to compute what ShrinkDemandedConstant
7792 // knew about to reconstruct a low-bits mask value.
7793 unsigned LZ = A.countl_zero();
7794 unsigned TZ = A.countr_zero();
7795 unsigned BitWidth = A.getBitWidth();
7796 KnownBits Known(BitWidth);
7797 computeKnownBits(BO->LHS, Known, getDataLayout(),
7798 0, &AC, nullptr, &DT);
7799
7800 APInt EffectiveMask =
7801 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7802 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7803 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7804 const SCEV *LHS = getSCEV(BO->LHS);
7805 const SCEV *ShiftedLHS = nullptr;
7806 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7807 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7808 // For an expression like (x * 8) & 8, simplify the multiply.
7809 unsigned MulZeros = OpC->getAPInt().countr_zero();
7810 unsigned GCD = std::min(MulZeros, TZ);
7811 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7812 SmallVector<const SCEV *, 4> MulOps;
7813 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7814 append_range(MulOps, LHSMul->operands().drop_front());
7815 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7816 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7817 }
7818 }
7819 if (!ShiftedLHS)
7820 ShiftedLHS = getUDivExpr(LHS, MulCount);
7821 return getMulExpr(
7822 getZeroExtendExpr(
7823 getTruncateExpr(ShiftedLHS,
7824 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7825 BO->LHS->getType()),
7826 MulCount);
7827 }
7828 }
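      // Editor's illustration (assumed example, not from the source):
      // "%r = and i32 %x, 255" is a pure low-bits mask, so it is modeled as
      // roughly (zext i8 (trunc i32 %x to i8) to i32) instead of an opaque
      // SCEVUnknown.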
7829 // Binary `and` is a bit-wise `umin`.
7830 if (BO->LHS->getType()->isIntegerTy(1)) {
7831 LHS = getSCEV(BO->LHS);
7832 RHS = getSCEV(BO->RHS);
7833 return getUMinExpr(LHS, RHS);
7834 }
7835 break;
7836
7837 case Instruction::Or:
7838 // Binary `or` is a bit-wise `umax`.
7839 if (BO->LHS->getType()->isIntegerTy(1)) {
7840 LHS = getSCEV(BO->LHS);
7841 RHS = getSCEV(BO->RHS);
7842 return getUMaxExpr(LHS, RHS);
7843 }
7844 break;
7845
7846 case Instruction::Xor:
7847 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7848 // If the RHS of xor is -1, then this is a not operation.
7849 if (CI->isMinusOne())
7850 return getNotSCEV(getSCEV(BO->LHS));
7851
7852 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7853 // This is a variant of the check for xor with -1, and it handles
7854 // the case where instcombine has trimmed non-demanded bits out
7855 // of an xor with -1.
7856 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7857 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7858 if (LBO->getOpcode() == Instruction::And &&
7859 LCI->getValue() == CI->getValue())
7860 if (const SCEVZeroExtendExpr *Z =
7861 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7862 Type *UTy = BO->LHS->getType();
7863 const SCEV *Z0 = Z->getOperand();
7864 Type *Z0Ty = Z0->getType();
7865 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7866
7867 // If C is a low-bits mask, the zero extend is serving to
7868 // mask off the high bits. Complement the operand and
7869 // re-apply the zext.
7870 if (CI->getValue().isMask(Z0TySize))
7871 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7872
7873 // If C is a single bit, it may be in the sign-bit position
7874 // before the zero-extend. In this case, represent the xor
7875 // using an add, which is equivalent, and re-apply the zext.
7876 APInt Trunc = CI->getValue().trunc(Z0TySize);
7877 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7878 Trunc.isSignMask())
7879 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7880 UTy);
7881 }
7882 }
7883 break;
7884
7885 case Instruction::Shl:
7886 // Turn shift left of a constant amount into a multiply.
7887 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7888 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7889
7890 // If the shift count is not less than the bitwidth, the result of
7891 // the shift is undefined. Don't try to analyze it, because the
7892 // resolution chosen here may differ from the resolution chosen in
7893 // other parts of the compiler.
7894 if (SA->getValue().uge(BitWidth))
7895 break;
7896
7897 // We can safely preserve the nuw flag in all cases. It's also safe to
7898 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7899 // requires special handling. It can be preserved as long as we're not
7900 // left shifting by bitwidth - 1.
7901 auto Flags = SCEV::FlagAnyWrap;
7902 if (BO->Op) {
7903 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7904 if ((MulFlags & SCEV::FlagNSW) &&
7905 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7906 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
7907 if (MulFlags & SCEV::FlagNUW)
7908 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
7909 }
7910
7911 ConstantInt *X = ConstantInt::get(
7912 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7913 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7914 }
7915 break;
7916
7917 case Instruction::AShr:
7918 // AShr X, C, where C is a constant.
7919 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7920 if (!CI)
7921 break;
7922
7923 Type *OuterTy = BO->LHS->getType();
7924 unsigned BitWidth = getTypeSizeInBits(OuterTy);
7925 // If the shift count is not less than the bitwidth, the result of
7926 // the shift is undefined. Don't try to analyze it, because the
7927 // resolution chosen here may differ from the resolution chosen in
7928 // other parts of the compiler.
7929 if (CI->getValue().uge(BitWidth))
7930 break;
7931
7932 if (CI->isZero())
7933 return getSCEV(BO->LHS); // shift by zero --> noop
7934
7935 uint64_t AShrAmt = CI->getZExtValue();
7936 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7937
7938 Operator *L = dyn_cast<Operator>(BO->LHS);
7939 const SCEV *AddTruncateExpr = nullptr;
7940 ConstantInt *ShlAmtCI = nullptr;
7941 const SCEV *AddConstant = nullptr;
7942
7943 if (L && L->getOpcode() == Instruction::Add) {
7944 // X = Shl A, n
7945 // Y = Add X, c
7946 // Z = AShr Y, m
7947 // n, c and m are constants.
7948
7949 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7950 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7951 if (LShift && LShift->getOpcode() == Instruction::Shl) {
7952 if (AddOperandCI) {
7953 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
7954 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
7955 // since we truncate to TruncTy, the AddConstant should be of the
7956 // same type, so create a new Constant with type same as TruncTy.
7957 // Also, the Add constant should be shifted right by AShr amount.
7958 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
7959 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
7960 // we model the expression as sext(add(trunc(A), c << n)), since the
7961 // sext(trunc) part is already handled below, we create an
7962 // AddExpr(TruncExp) which will be used later.
7963 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7964 }
7965 }
7966 } else if (L && L->getOpcode() == Instruction::Shl) {
7967 // X = Shl A, n
7968 // Y = AShr X, m
7969 // Both n and m are constant.
7970
7971 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
7972 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7973 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7974 }
7975
7976 if (AddTruncateExpr && ShlAmtCI) {
7977 // We can merge the two given cases into a single SCEV statement,
7978 // incase n = m, the mul expression will be 2^0, so it gets resolved to
7979 // in case n = m, the mul expression will be 2^0, so it gets resolved to
7980 //
7981 // 1) For a two-shift sext-inreg, i.e. n = m,
7982 // use sext(trunc(x)) as the SCEV expression.
7983 //
7984 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
7985 // expression. We already checked that ShlAmt < BitWidth, so
7986 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7987 // ShlAmt - AShrAmt < Amt.
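// e.g. with n = 5 and m = 3 the result is sext(mul(trunc(X), 4)).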
7988 const APInt &ShlAmt = ShlAmtCI->getValue();
7989 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
7990 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
7991 ShlAmtCI->getZExtValue() - AShrAmt);
7992 const SCEV *CompositeExpr =
7993 getMulExpr(AddTruncateExpr, getConstant(Mul));
7994 if (L->getOpcode() != Instruction::Shl)
7995 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
7996
7997 return getSignExtendExpr(CompositeExpr, OuterTy);
7998 }
7999 }
8000 break;
8001 }
8002 }
8003
8004 switch (U->getOpcode()) {
8005 case Instruction::Trunc:
8006 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8007
8008 case Instruction::ZExt:
8009 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8010
8011 case Instruction::SExt:
8012 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8013 dyn_cast<Instruction>(V))) {
8014 // The NSW flag of a subtract does not always survive the conversion to
8015 // A + (-1)*B. By pushing sign extension onto its operands we are much
8016 // more likely to preserve NSW and allow later AddRec optimisations.
8017 //
8018 // NOTE: This is effectively duplicating this logic from getSignExtend:
8019 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8020 // but by that point the NSW information has potentially been lost.
8021 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8022 Type *Ty = U->getType();
8023 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8024 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8025 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8026 }
8027 }
8028 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8029
8030 case Instruction::BitCast:
8031 // BitCasts are no-op casts so we just eliminate the cast.
8032 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8033 return getSCEV(U->getOperand(0));
8034 break;
8035
8036 case Instruction::PtrToInt: {
8037 // Pointer to integer cast is straightforward, so do model it.
8038 const SCEV *Op = getSCEV(U->getOperand(0));
8039 Type *DstIntTy = U->getType();
8040 // But only if effective SCEV (integer) type is wide enough to represent
8041 // all possible pointer values.
8042 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8043 if (isa<SCEVCouldNotCompute>(IntOp))
8044 return getUnknown(V);
8045 return IntOp;
8046 }
8047 case Instruction::IntToPtr:
8048 // Just don't deal with inttoptr casts.
8049 return getUnknown(V);
8050
8051 case Instruction::SDiv:
8052 // If both operands are non-negative, this is just an udiv.
8053 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8054 isKnownNonNegative(getSCEV(U->getOperand(1))))
8055 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8056 break;
8057
8058 case Instruction::SRem:
8059 // If both operands are non-negative, this is just an urem.
8060 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8061 isKnownNonNegative(getSCEV(U->getOperand(1))))
8062 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8063 break;
8064
8065 case Instruction::GetElementPtr:
8066 return createNodeForGEP(cast<GEPOperator>(U));
8067
8068 case Instruction::PHI:
8069 return createNodeForPHI(cast<PHINode>(U));
8070
8071 case Instruction::Select:
8072 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8073 U->getOperand(2));
8074
8075 case Instruction::Call:
8076 case Instruction::Invoke:
8077 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8078 return getSCEV(RV);
8079
8080 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8081 switch (II->getIntrinsicID()) {
8082 case Intrinsic::abs:
8083 return getAbsExpr(
8084 getSCEV(II->getArgOperand(0)),
8085 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8086 case Intrinsic::umax:
8087 LHS = getSCEV(II->getArgOperand(0));
8088 RHS = getSCEV(II->getArgOperand(1));
8089 return getUMaxExpr(LHS, RHS);
8090 case Intrinsic::umin:
8091 LHS = getSCEV(II->getArgOperand(0));
8092 RHS = getSCEV(II->getArgOperand(1));
8093 return getUMinExpr(LHS, RHS);
8094 case Intrinsic::smax:
8095 LHS = getSCEV(II->getArgOperand(0));
8096 RHS = getSCEV(II->getArgOperand(1));
8097 return getSMaxExpr(LHS, RHS);
8098 case Intrinsic::smin:
8099 LHS = getSCEV(II->getArgOperand(0));
8100 RHS = getSCEV(II->getArgOperand(1));
8101 return getSMinExpr(LHS, RHS);
8102 case Intrinsic::usub_sat: {
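// usub.sat(X, Y) clamps at zero, i.e. X - umin(X, Y), so the subtraction
// below can carry the NUW flag.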
8103 const SCEV *X = getSCEV(II->getArgOperand(0));
8104 const SCEV *Y = getSCEV(II->getArgOperand(1));
8105 const SCEV *ClampedY = getUMinExpr(X, Y);
8106 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8107 }
8108 case Intrinsic::uadd_sat: {
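// uadd.sat(X, Y) saturates at the maximum value, i.e. umin(X, ~Y) + Y, so
// the addition below can carry the NUW flag.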
8109 const SCEV *X = getSCEV(II->getArgOperand(0));
8110 const SCEV *Y = getSCEV(II->getArgOperand(1));
8111 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8112 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8113 }
8114 case Intrinsic::start_loop_iterations:
8115 case Intrinsic::annotation:
8116 case Intrinsic::ptr_annotation:
8117 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8118 // just equivalent to the first operand for SCEV purposes.
8119 return getSCEV(II->getArgOperand(0));
8120 case Intrinsic::vscale:
8121 return getVScale(II->getType());
8122 default:
8123 break;
8124 }
8125 }
8126 break;
8127 }
8128
8129 return getUnknown(V);
8130}
8131
8132//===----------------------------------------------------------------------===//
8133// Iteration Count Computation Code
8134//
8135
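// The "trip count" of a loop is its exit (backedge-taken) count plus one: an
// exit count of 7 means the loop body runs 8 times. Evaluating in a type one
// bit wider than the exit count guarantees the +1 cannot overflow.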
8136const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8137 if (isa<SCEVCouldNotCompute>(ExitCount))
8138 return getCouldNotCompute();
8139
8140 auto *ExitCountType = ExitCount->getType();
8141 assert(ExitCountType->isIntegerTy());
8142 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8143 1 + ExitCountType->getScalarSizeInBits());
8144 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8145}
8146
8147const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8148 Type *EvalTy,
8149 const Loop *L) {
8150 if (isa<SCEVCouldNotCompute>(ExitCount))
8151 return getCouldNotCompute();
8152
8153 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8154 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8155
8156 auto CanAddOneWithoutOverflow = [&]() {
8157 ConstantRange ExitCountRange =
8158 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8159 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8160 return true;
8161
8162 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8163 getMinusOne(ExitCount->getType()));
8164 };
8165
8166 // If we need to zero extend the backedge count, check if we can add one to
8167 // it prior to zero extending without overflow. Provided this is safe, it
8168 // allows better simplification of the +1.
8169 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8170 return getZeroExtendExpr(
8171 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8172
8173 // Get the total trip count from the count by adding 1. This may wrap.
8174 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8175}
8176
8177static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8178 if (!ExitCount)
8179 return 0;
8180
8181 ConstantInt *ExitConst = ExitCount->getValue();
8182
8183 // Guard against huge trip counts.
8184 if (ExitConst->getValue().getActiveBits() > 32)
8185 return 0;
8186
8187 // In case of integer overflow, this returns 0, which is correct.
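// e.g. a constant exit count of 41 yields a trip count of 42, while an exit
// count of UINT32_MAX wraps the +1 to 0.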
8188 return ((unsigned)ExitConst->getZExtValue()) + 1;
8189}
8190
8191unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8192 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8193 return getConstantTripCount(ExitCount);
8194}
8195
8196unsigned
8197ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8198 const BasicBlock *ExitingBlock) {
8199 assert(ExitingBlock && "Must pass a non-null exiting block!");
8200 assert(L->isLoopExiting(ExitingBlock) &&
8201 "Exiting block must actually branch out of the loop!");
8202 const SCEVConstant *ExitCount =
8203 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8204 return getConstantTripCount(ExitCount);
8205}
8206
8207unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
8208 const auto *MaxExitCount =
8209 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
8210 return getConstantTripCount(MaxExitCount);
8211}
8212
8213unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8214 SmallVector<BasicBlock *, 8> ExitingBlocks;
8215 L->getExitingBlocks(ExitingBlocks);
8216
8217 std::optional<unsigned> Res;
8218 for (auto *ExitingBB : ExitingBlocks) {
8219 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8220 if (!Res)
8221 Res = Multiple;
8222 Res = (unsigned)std::gcd(*Res, Multiple);
8223 }
8224 return Res.value_or(1);
8225}
8226
8227unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8228 const SCEV *ExitCount) {
8229 if (ExitCount == getCouldNotCompute())
8230 return 1;
8231
8232 // Get the trip count
8233 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8234
8235 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8236 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8237 // the greatest power of 2 divisor less than 2^32.
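// e.g. a trip count known to be 4 * %n yields 4, while a multiple of 2^40 is
// reported as 2^31.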
8238 return Multiple.getActiveBits() > 32
8239 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8240 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8241}
8242
8243/// Returns the largest constant divisor of the trip count of this loop as a
8244/// normal unsigned value, if possible. This means that the actual trip count is
8245/// always a multiple of the returned value (don't forget the trip count could
8246/// very well be zero as well!).
8247///
8248/// Returns 1 if the trip count is unknown or not guaranteed to be the
8249/// multiple of a constant (which is also the case if the trip count is simply
8250/// constant, use getSmallConstantTripCount for that case), Will also return 1
8251/// if the trip count is very large (>= 2^32).
8252///
8253/// As explained in the comments for getSmallConstantTripCount, this assumes
8254/// that control exits the loop via ExitingBlock.
8255unsigned
8256ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8257 const BasicBlock *ExitingBlock) {
8258 assert(ExitingBlock && "Must pass a non-null exiting block!");
8259 assert(L->isLoopExiting(ExitingBlock) &&
8260 "Exiting block must actually branch out of the loop!");
8261 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8262 return getSmallConstantTripMultiple(L, ExitCount);
8263}
8264
8265const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8266 const BasicBlock *ExitingBlock,
8267 ExitCountKind Kind) {
8268 switch (Kind) {
8269 case Exact:
8270 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8271 case SymbolicMaximum:
8272 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8273 case ConstantMaximum:
8274 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8275 };
8276 llvm_unreachable("Invalid ExitCountKind!");
8277}
8278
8279const SCEV *
8280ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
8281 SmallVector<const SCEVPredicate *, 4> &Preds) {
8282 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8283}
8284
8285const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8286 ExitCountKind Kind) {
8287 switch (Kind) {
8288 case Exact:
8289 return getBackedgeTakenInfo(L).getExact(L, this);
8290 case ConstantMaximum:
8291 return getBackedgeTakenInfo(L).getConstantMax(this);
8292 case SymbolicMaximum:
8293 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8294 };
8295 llvm_unreachable("Invalid ExitCountKind!");
8296}
8297
8298bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8299 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8300}
8301
8302/// Push PHI nodes in the header of the given loop onto the given Worklist.
8303static void PushLoopPHIs(const Loop *L,
8304 SmallVectorImpl<Instruction *> &Worklist,
8305 SmallPtrSetImpl<Instruction *> &Visited) {
8306 BasicBlock *Header = L->getHeader();
8307
8308 // Push all Loop-header PHIs onto the Worklist stack.
8309 for (PHINode &PN : Header->phis())
8310 if (Visited.insert(&PN).second)
8311 Worklist.push_back(&PN);
8312}
8313
8314const ScalarEvolution::BackedgeTakenInfo &
8315ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8316 auto &BTI = getBackedgeTakenInfo(L);
8317 if (BTI.hasFullInfo())
8318 return BTI;
8319
8320 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8321
8322 if (!Pair.second)
8323 return Pair.first->second;
8324
8325 BackedgeTakenInfo Result =
8326 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8327
8328 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8329}
8330
8331ScalarEvolution::BackedgeTakenInfo &
8332ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8333 // Initially insert an invalid entry for this loop. If the insertion
8334 // succeeds, proceed to actually compute a backedge-taken count and
8335 // update the value. The temporary CouldNotCompute value tells SCEV
8336 // code elsewhere that it shouldn't attempt to request a new
8337 // backedge-taken count, which could result in infinite recursion.
8338 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8339 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8340 if (!Pair.second)
8341 return Pair.first->second;
8342
8343 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8344 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8345 // must be cleared in this scope.
8346 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8347
8348 // Now that we know more about the trip count for this loop, forget any
8349 // existing SCEV values for PHI nodes in this loop since they are only
8350 // conservative estimates made without the benefit of trip count
8351 // information. This invalidation is not necessary for correctness, and is
8352 // only done to produce more precise results.
8353 if (Result.hasAnyInfo()) {
8354 // Invalidate any expression using an addrec in this loop.
8355 SmallVector<const SCEV *, 8> ToForget;
8356 auto LoopUsersIt = LoopUsers.find(L);
8357 if (LoopUsersIt != LoopUsers.end())
8358 append_range(ToForget, LoopUsersIt->second);
8359 forgetMemoizedResults(ToForget);
8360
8361 // Invalidate constant-evolved loop header phis.
8362 for (PHINode &PN : L->getHeader()->phis())
8363 ConstantEvolutionLoopExitValue.erase(&PN);
8364 }
8365
8366 // Re-lookup the insert position, since the call to
8367 // computeBackedgeTakenCount above could result in a
8368 // recursive call to getBackedgeTakenInfo (on a different
8369 // loop), which would invalidate the iterator computed
8370 // earlier.
8371 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8372}
8373
8374void ScalarEvolution::forgetAllLoops() {
8375 // This method is intended to forget all info about loops. It should
8376 // invalidate caches as if the following happened:
8377 // - The trip counts of all loops have changed arbitrarily
8378 // - Every llvm::Value has been updated in place to produce a different
8379 // result.
8380 BackedgeTakenCounts.clear();
8381 PredicatedBackedgeTakenCounts.clear();
8382 BECountUsers.clear();
8383 LoopPropertiesCache.clear();
8384 ConstantEvolutionLoopExitValue.clear();
8385 ValueExprMap.clear();
8386 ValuesAtScopes.clear();
8387 ValuesAtScopesUsers.clear();
8388 LoopDispositions.clear();
8389 BlockDispositions.clear();
8390 UnsignedRanges.clear();
8391 SignedRanges.clear();
8392 ExprValueMap.clear();
8393 HasRecMap.clear();
8394 ConstantMultipleCache.clear();
8395 PredicatedSCEVRewrites.clear();
8396 FoldCache.clear();
8397 FoldCacheUser.clear();
8398}
8399void ScalarEvolution::visitAndClearUsers(
8400 SmallVectorImpl<Instruction *> &Worklist,
8401 SmallPtrSetImpl<Instruction *> &Visited,
8402 SmallVectorImpl<const SCEV *> &ToForget) {
8403 while (!Worklist.empty()) {
8404 Instruction *I = Worklist.pop_back_val();
8405 if (!isSCEVable(I->getType()))
8406 continue;
8407
8408 ValueExprMapType::iterator It =
8409 ValueExprMap.find_as(static_cast<Value *>(I));
8410 if (It != ValueExprMap.end()) {
8411 eraseValueFromMap(It->first);
8412 ToForget.push_back(It->second);
8413 if (PHINode *PN = dyn_cast<PHINode>(I))
8414 ConstantEvolutionLoopExitValue.erase(PN);
8415 }
8416
8417 PushDefUseChildren(I, Worklist, Visited);
8418 }
8419}
8420
8421void ScalarEvolution::forgetLoop(const Loop *L) {
8422 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8423 SmallVector<Instruction *, 32> Worklist;
8424 SmallPtrSet<Instruction *, 16> Visited;
8425 SmallVector<const SCEV *, 16> ToForget;
8426
8427 // Iterate over all the loops and sub-loops to drop SCEV information.
8428 while (!LoopWorklist.empty()) {
8429 auto *CurrL = LoopWorklist.pop_back_val();
8430
8431 // Drop any stored trip count value.
8432 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8433 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8434
8435 // Drop information about predicated SCEV rewrites for this loop.
8436 for (auto I = PredicatedSCEVRewrites.begin();
8437 I != PredicatedSCEVRewrites.end();) {
8438 std::pair<const SCEV *, const Loop *> Entry = I->first;
8439 if (Entry.second == CurrL)
8440 PredicatedSCEVRewrites.erase(I++);
8441 else
8442 ++I;
8443 }
8444
8445 auto LoopUsersItr = LoopUsers.find(CurrL);
8446 if (LoopUsersItr != LoopUsers.end()) {
8447 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8448 LoopUsersItr->second.end());
8449 }
8450
8451 // Drop information about expressions based on loop-header PHIs.
8452 PushLoopPHIs(CurrL, Worklist, Visited);
8453 visitAndClearUsers(Worklist, Visited, ToForget);
8454
8455 LoopPropertiesCache.erase(CurrL);
8456 // Forget all contained loops too, to avoid dangling entries in the
8457 // ValuesAtScopes map.
8458 LoopWorklist.append(CurrL->begin(), CurrL->end());
8459 }
8460 forgetMemoizedResults(ToForget);
8461}
8462
8463void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8464 forgetLoop(L->getOutermostLoop());
8465}
8466
8467void ScalarEvolution::forgetValue(Value *V) {
8468 Instruction *I = dyn_cast<Instruction>(V);
8469 if (!I) return;
8470
8471 // Drop information about expressions based on loop-header PHIs.
8472 SmallVector<Instruction *, 16> Worklist;
8473 SmallPtrSet<Instruction *, 8> Visited;
8474 SmallVector<const SCEV *, 8> ToForget;
8475 Worklist.push_back(I);
8476 Visited.insert(I);
8477 visitAndClearUsers(Worklist, Visited, ToForget);
8478
8479 forgetMemoizedResults(ToForget);
8480}
8481
8482void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8483 if (!isSCEVable(V->getType()))
8484 return;
8485
8486 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8487 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8488 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8489 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8490 if (const SCEV *S = getExistingSCEV(V)) {
8491 struct InvalidationRootCollector {
8492 Loop *L;
8493 SmallVector<const SCEV *, 8> Roots;
8494
8495 InvalidationRootCollector(Loop *L) : L(L) {}
8496
8497 bool follow(const SCEV *S) {
8498 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8499 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8500 if (L->contains(I))
8501 Roots.push_back(S);
8502 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8503 if (L->contains(AddRec->getLoop()))
8504 Roots.push_back(S);
8505 }
8506 return true;
8507 }
8508 bool isDone() const { return false; }
8509 };
8510
8511 InvalidationRootCollector C(L);
8512 visitAll(S, C);
8513 forgetMemoizedResults(C.Roots);
8514 }
8515
8516 // Also perform the normal invalidation.
8517 forgetValue(V);
8518}
8519
8520void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8521
8522void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8523 // Unless a specific value is passed to invalidation, completely clear both
8524 // caches.
8525 if (!V) {
8526 BlockDispositions.clear();
8527 LoopDispositions.clear();
8528 return;
8529 }
8530
8531 if (!isSCEVable(V->getType()))
8532 return;
8533
8534 const SCEV *S = getExistingSCEV(V);
8535 if (!S)
8536 return;
8537
8538 // Invalidate the block and loop dispositions cached for S. Dispositions of
8539 // S's users may change if S's disposition changes (i.e. a user may change to
8540 // loop-invariant, if S changes to loop invariant), so also invalidate
8541 // dispositions of S's users recursively.
8542 SmallVector<const SCEV *, 8> Worklist = {S};
8543 SmallPtrSet<const SCEV *, 8> Seen = {S};
8544 while (!Worklist.empty()) {
8545 const SCEV *Curr = Worklist.pop_back_val();
8546 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8547 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8548 if (!LoopDispoRemoved && !BlockDispoRemoved)
8549 continue;
8550 auto Users = SCEVUsers.find(Curr);
8551 if (Users != SCEVUsers.end())
8552 for (const auto *User : Users->second)
8553 if (Seen.insert(User).second)
8554 Worklist.push_back(User);
8555 }
8556}
8557
8558/// Get the exact loop backedge taken count considering all loop exits. A
8559/// computable result can only be returned for loops with all exiting blocks
8560/// dominating the latch. howFarToZero assumes that the limit of each loop test
8561/// is never skipped. This is a valid assumption as long as the loop exits via
8562/// that test. For precise results, it is the caller's responsibility to specify
8563/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8564const SCEV *
8565ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
8566 SmallVectorImpl<const SCEVPredicate *> *Preds) const {
8567 // If any exits were not computable, the loop is not computable.
8568 if (!isComplete() || ExitNotTaken.empty())
8569 return SE->getCouldNotCompute();
8570
8571 const BasicBlock *Latch = L->getLoopLatch();
8572 // All exiting blocks we have collected must dominate the only backedge.
8573 if (!Latch)
8574 return SE->getCouldNotCompute();
8575
8576 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8577 // count is simply a minimum out of all these calculated exit counts.
8578 SmallVector<const SCEV *, 2> Ops;
8579 for (const auto &ENT : ExitNotTaken) {
8580 const SCEV *BECount = ENT.ExactNotTaken;
8581 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8582 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8583 "We should only have known counts for exiting blocks that dominate "
8584 "latch!");
8585
8586 Ops.push_back(BECount);
8587
8588 if (Preds)
8589 for (const auto *P : ENT.Predicates)
8590 Preds->push_back(P);
8591
8592 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8593 "Predicate should be always true!");
8594 }
8595
8596 // If an earlier exit exits on the first iteration (exit count zero), then
8597 // a later poison exit count should not propagate into the result. These are
8598 // exactly the semantics provided by umin_seq.
8599 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8600}
8601
8602/// Get the exact not taken count for this loop exit.
8603const SCEV *
8604ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
8605 ScalarEvolution *SE) const {
8606 for (const auto &ENT : ExitNotTaken)
8607 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8608 return ENT.ExactNotTaken;
8609
8610 return SE->getCouldNotCompute();
8611}
8612
8613const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8614 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8615 for (const auto &ENT : ExitNotTaken)
8616 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8617 return ENT.ConstantMaxNotTaken;
8618
8619 return SE->getCouldNotCompute();
8620}
8621
8622const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8623 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8624 for (const auto &ENT : ExitNotTaken)
8625 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8626 return ENT.SymbolicMaxNotTaken;
8627
8628 return SE->getCouldNotCompute();
8629}
8630
8631/// getConstantMax - Get the constant max backedge taken count for the loop.
8632const SCEV *
8633ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
8634 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8635 return !ENT.hasAlwaysTruePredicate();
8636 };
8637
8638 if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
8639 return SE->getCouldNotCompute();
8640
8641 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8642 isa<SCEVConstant>(getConstantMax())) &&
8643 "No point in having a non-constant max backedge taken count!");
8644 return getConstantMax();
8645}
8646
8647const SCEV *
8648ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
8649 ScalarEvolution *SE) {
8650 if (!SymbolicMax)
8651 SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L);
8652 return SymbolicMax;
8653}
8654
8655bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8656 ScalarEvolution *SE) const {
8657 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8658 return !ENT.hasAlwaysTruePredicate();
8659 };
8660 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8661}
8662
8663ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8664 : ExitLimit(E, E, E, false, std::nullopt) {}
8665
8666ScalarEvolution::ExitLimit::ExitLimit(
8667 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8668 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8669 ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
8670 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8671 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8672 // If we prove the max count is zero, so is the symbolic bound. This happens
8673 // in practice due to differences in a) how context sensitive we've chosen
8674 // to be and b) how we reason about bounds implied by UB.
8675 if (ConstantMaxNotTaken->isZero()) {
8676 this->ExactNotTaken = E = ConstantMaxNotTaken;
8677 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8678 }
8679
8680 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8681 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8682 "Exact is not allowed to be less precise than Constant Max");
8683 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8684 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8685 "Exact is not allowed to be less precise than Symbolic Max");
8686 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8687 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8688 "Symbolic Max is not allowed to be less precise than Constant Max");
8689 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8690 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8691 "No point in having a non-constant max backedge taken count!");
8692 for (const auto *PredSet : PredSetList)
8693 for (const auto *P : *PredSet)
8694 addPredicate(P);
8695 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8696 "Backedge count should be int");
8697 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8698 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8699 "Max backedge count should be int");
8700}
8701
8702ScalarEvolution::ExitLimit::ExitLimit(
8703 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8704 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8705 const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
8706 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8707 { &PredSet }) {}
8708
8709/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8710/// computable exit into a persistent ExitNotTakenInfo array.
8711ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8712 ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
8713 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8714 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8715 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8716
8717 ExitNotTaken.reserve(ExitCounts.size());
8718 std::transform(ExitCounts.begin(), ExitCounts.end(),
8719 std::back_inserter(ExitNotTaken),
8720 [&](const EdgeExitInfo &EEI) {
8721 BasicBlock *ExitBB = EEI.first;
8722 const ExitLimit &EL = EEI.second;
8723 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8724 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8725 EL.Predicates);
8726 });
8727 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8728 isa<SCEVConstant>(ConstantMax)) &&
8729 "No point in having a non-constant max backedge taken count!");
8730}
8731
8732/// Compute the number of times the backedge of the specified loop will execute.
8733ScalarEvolution::BackedgeTakenInfo
8734ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8735 bool AllowPredicates) {
8736 SmallVector<BasicBlock *, 8> ExitingBlocks;
8737 L->getExitingBlocks(ExitingBlocks);
8738
8739 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8740
8741 SmallVector<EdgeExitInfo, 4> ExitCounts;
8742 bool CouldComputeBECount = true;
8743 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8744 const SCEV *MustExitMaxBECount = nullptr;
8745 const SCEV *MayExitMaxBECount = nullptr;
8746 bool MustExitMaxOrZero = false;
8747
8748 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8749 // and compute maxBECount.
8750 // Do a union of all the predicates here.
8751 for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
8752 BasicBlock *ExitBB = ExitingBlocks[i];
8753
8754 // We canonicalize untaken exits to br (constant), ignore them so that
8755 // proving an exit untaken doesn't negatively impact our ability to reason
8756 // about the loop as a whole.
8757 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8758 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8759 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8760 if (ExitIfTrue == CI->isZero())
8761 continue;
8762 }
8763
8764 ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
8765
8766 assert((AllowPredicates || EL.Predicates.empty()) &&
8767 "Predicated exit limit when predicates are not allowed!");
8768
8769 // 1. For each exit that can be computed, add an entry to ExitCounts.
8770 // CouldComputeBECount is true only if all exits can be computed.
8771 if (EL.ExactNotTaken != getCouldNotCompute())
8772 ++NumExitCountsComputed;
8773 else
8774 // We couldn't compute an exact value for this exit, so
8775 // we won't be able to compute an exact value for the loop.
8776 CouldComputeBECount = false;
8777 // Remember exit count if either exact or symbolic is known. Because
8778 // Exact always implies symbolic, only check symbolic.
8779 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8780 ExitCounts.emplace_back(ExitBB, EL);
8781 else {
8782 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8783 "Exact is known but symbolic isn't?");
8784 ++NumExitCountsNotComputed;
8785 }
8786
8787 // 2. Derive the loop's MaxBECount from each exit's max number of
8788 // non-exiting iterations. Partition the loop exits into two kinds:
8789 // LoopMustExits and LoopMayExits.
8790 //
8791 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8792 // is a LoopMayExit. If any computable LoopMustExit is found, then
8793 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8794 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8795 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8796 // any
8797 // computable EL.ConstantMaxNotTaken.
8798 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8799 DT.dominates(ExitBB, Latch)) {
8800 if (!MustExitMaxBECount) {
8801 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8802 MustExitMaxOrZero = EL.MaxOrZero;
8803 } else {
8804 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8805 EL.ConstantMaxNotTaken);
8806 }
8807 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8808 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8809 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8810 else {
8811 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8812 EL.ConstantMaxNotTaken);
8813 }
8814 }
8815 }
8816 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8817 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8818 // The loop backedge will be taken the maximum or zero times if there's
8819 // a single exit that must be taken the maximum or zero times.
8820 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8821
8822 // Remember which SCEVs are used in exit limits for invalidation purposes.
8823 // We only care about non-constant SCEVs here, so we can ignore
8824 // EL.ConstantMaxNotTaken
8825 // and MaxBECount, which must be SCEVConstant.
8826 for (const auto &Pair : ExitCounts) {
8827 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8828 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8829 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8830 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8831 {L, AllowPredicates});
8832 }
8833 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8834 MaxBECount, MaxOrZero);
8835}
8836
8837ScalarEvolution::ExitLimit
8838ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8839 bool AllowPredicates) {
8840 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8841 // If our exiting block does not dominate the latch, then its connection with
8842 // loop's exit limit may be far from trivial.
8843 const BasicBlock *Latch = L->getLoopLatch();
8844 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8845 return getCouldNotCompute();
8846
8847 bool IsOnlyExit = (L->getExitingBlock() != nullptr);
8848 Instruction *Term = ExitingBlock->getTerminator();
8849 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8850 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8851 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8852 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8853 "It should have one successor in loop and one exit block!");
8854 // Proceed to the next level to examine the exit condition expression.
8855 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8856 /*ControlsOnlyExit=*/IsOnlyExit,
8857 AllowPredicates);
8858 }
8859
8860 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8861 // For switch, make sure that there is a single exit from the loop.
8862 BasicBlock *Exit = nullptr;
8863 for (auto *SBB : successors(ExitingBlock))
8864 if (!L->contains(SBB)) {
8865 if (Exit) // Multiple exit successors.
8866 return getCouldNotCompute();
8867 Exit = SBB;
8868 }
8869 assert(Exit && "Exiting block must have at least one exit");
8870 return computeExitLimitFromSingleExitSwitch(
8871 L, SI, Exit,
8872 /*ControlsOnlyExit=*/IsOnlyExit);
8873 }
8874
8875 return getCouldNotCompute();
8876}
8877
8878ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8879 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8880 bool AllowPredicates) {
8881 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8882 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8883 ControlsOnlyExit, AllowPredicates);
8884}
8885
8886std::optional<ScalarEvolution::ExitLimit>
8887ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8888 bool ExitIfTrue, bool ControlsOnlyExit,
8889 bool AllowPredicates) {
8890 (void)this->L;
8891 (void)this->ExitIfTrue;
8892 (void)this->AllowPredicates;
8893
8894 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8895 this->AllowPredicates == AllowPredicates &&
8896 "Variance in assumed invariant key components!");
8897 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8898 if (Itr == TripCountMap.end())
8899 return std::nullopt;
8900 return Itr->second;
8901}
8902
8903void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8904 bool ExitIfTrue,
8905 bool ControlsOnlyExit,
8906 bool AllowPredicates,
8907 const ExitLimit &EL) {
8908 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8909 this->AllowPredicates == AllowPredicates &&
8910 "Variance in assumed invariant key components!");
8911
8912 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
8913 assert(InsertResult.second && "Expected successful insertion!");
8914 (void)InsertResult;
8915 (void)ExitIfTrue;
8916}
8917
8918ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
8919 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8920 bool ControlsOnlyExit, bool AllowPredicates) {
8921
8922 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
8923 AllowPredicates))
8924 return *MaybeEL;
8925
8926 ExitLimit EL = computeExitLimitFromCondImpl(
8927 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
8928 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
8929 return EL;
8930}
8931
8932ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
8933 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8934 bool ControlsOnlyExit, bool AllowPredicates) {
8935 // Handle BinOp conditions (And, Or).
8936 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
8937 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
8938 return *LimitFromBinOp;
8939
8940 // With an icmp, it may be feasible to compute an exact backedge-taken count.
8941 // Proceed to the next level to examine the icmp.
8942 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
8943 ExitLimit EL =
8944 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
8945 if (EL.hasFullInfo() || !AllowPredicates)
8946 return EL;
8947
8948 // Try again, but use SCEV predicates this time.
8949 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
8950 ControlsOnlyExit,
8951 /*AllowPredicates=*/true);
8952 }
8953
8954 // Check for a constant condition. These are normally stripped out by
8955 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
8956 // preserve the CFG and is temporarily leaving constant conditions
8957 // in place.
8958 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
8959 if (ExitIfTrue == !CI->getZExtValue())
8960 // The backedge is always taken.
8961 return getCouldNotCompute();
8962 // The backedge is never taken.
8963 return getZero(CI->getType());
8964 }
8965
8966 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
8967 // with a constant step, we can form an equivalent icmp predicate and figure
8968 // out how many iterations will be taken before we exit.
8969 const WithOverflowInst *WO;
8970 const APInt *C;
8971 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
8972 match(WO->getRHS(), m_APInt(C))) {
8973 ConstantRange NWR =
8974 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
8975 WO->getNoWrapKind());
8976 CmpInst::Predicate Pred;
8977 APInt NewRHSC, Offset;
8978 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
8979 if (!ExitIfTrue)
8980 Pred = ICmpInst::getInversePredicate(Pred);
8981 auto *LHS = getSCEV(WO->getLHS());
8982 if (Offset != 0)
8983 LHS = getAddExpr(LHS, getConstant(Offset));
8984 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
8985 ControlsOnlyExit, AllowPredicates);
8986 if (EL.hasAnyInfo())
8987 return EL;
8988 }
8989
8990 // If it's not an integer or pointer comparison then compute it the hard way.
8991 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8992}
8993
8994std::optional<ScalarEvolution::ExitLimit>
8995ScalarEvolution::computeExitLimitFromCondFromBinOp(
8996 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8997 bool ControlsOnlyExit, bool AllowPredicates) {
8998 // Check if the controlling expression for this loop is an And or Or.
8999 Value *Op0, *Op1;
9000 bool IsAnd = false;
9001 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9002 IsAnd = true;
9003 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9004 IsAnd = false;
9005 else
9006 return std::nullopt;
9007
9008 // EitherMayExit is true in these two cases:
9009 // br (and Op0 Op1), loop, exit
9010 // br (or Op0 Op1), exit, loop
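// e.g. for "br i1 (and %c0, %c1), loop, exit" the loop is left on the first
// iteration in which either %c0 or %c1 is false.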
9011 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9012 ExitLimit EL0 = computeExitLimitFromCondCached(
9013 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9014 AllowPredicates);
9015 ExitLimit EL1 = computeExitLimitFromCondCached(
9016 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9017 AllowPredicates);
9018
9019 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9020 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9021 if (isa<ConstantInt>(Op1))
9022 return Op1 == NeutralElement ? EL0 : EL1;
9023 if (isa<ConstantInt>(Op0))
9024 return Op0 == NeutralElement ? EL1 : EL0;
9025
9026 const SCEV *BECount = getCouldNotCompute();
9027 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9028 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9029 if (EitherMayExit) {
9030 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9031 // The loop continues only while both conditions allow it to, so it exits
9032 // as soon as either one would exit. Choose the less conservative count.
9033 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9034 EL1.ExactNotTaken != getCouldNotCompute()) {
9035 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9036 UseSequentialUMin);
9037 }
9038 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9039 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9040 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9041 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9042 else
9043 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9044 EL1.ConstantMaxNotTaken);
9045 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9046 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9047 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9048 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9049 else
9050 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9051 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9052 } else {
9053 // Both conditions must trigger the exit at the same time for the loop to exit.
9054 // For now, be conservative.
9055 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9056 BECount = EL0.ExactNotTaken;
9057 }
9058
9059 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9060 // to be more aggressive when computing BECount than when computing
9061 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9062 // and
9063 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9064 // EL1.ConstantMaxNotTaken to not.
9065 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9066 !isa<SCEVCouldNotCompute>(BECount))
9067 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9068 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9069 SymbolicMaxBECount =
9070 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9071 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9072 { &EL0.Predicates, &EL1.Predicates });
9073}
9074
9075ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9076 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9077 bool AllowPredicates) {
9078 // If the condition was exit on true, convert the condition to exit on false
9079 ICmpInst::Predicate Pred;
9080 if (!ExitIfTrue)
9081 Pred = ExitCond->getPredicate();
9082 else
9083 Pred = ExitCond->getInversePredicate();
9084 const ICmpInst::Predicate OriginalPred = Pred;
9085
9086 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9087 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9088
9089 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9090 AllowPredicates);
9091 if (EL.hasAnyInfo())
9092 return EL;
9093
9094 auto *ExhaustiveCount =
9095 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9096
9097 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9098 return ExhaustiveCount;
9099
9100 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9101 ExitCond->getOperand(1), L, OriginalPred);
9102}
9103ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9104 const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
9105 bool ControlsOnlyExit, bool AllowPredicates) {
9106
9107 // Try to evaluate any dependencies out of the loop.
9108 LHS = getSCEVAtScope(LHS, L);
9109 RHS = getSCEVAtScope(RHS, L);
9110
9111 // At this point, we would like to compute how many iterations of the
9112 // loop the predicate will return true for these inputs.
9113 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9114 // If there is a loop-invariant, force it into the RHS.
9115 std::swap(LHS, RHS);
9116 Pred = ICmpInst::getSwappedPredicate(Pred);
9117 }
9118
9119 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9120 loopIsFiniteByAssumption(L);
9121 // Simplify the operands before analyzing them.
9122 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9123
9124 // If we have a comparison of a chrec against a constant, try to use value
9125 // ranges to answer this query.
9126 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9127 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9128 if (AddRec->getLoop() == L) {
9129 // Form the constant range.
9130 ConstantRange CompRange =
9131 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9132
9133 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9134 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9135 }
9136
9137 // If this loop must exit based on this condition (or execute undefined
9138 // behaviour), and we can prove the test sequence produced must repeat
9139 // the same values on self-wrap of the IV, then we can infer that IV
9140 // doesn't self wrap because if it did, we'd have an infinite (undefined)
9141 // loop.
9142 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9143 // TODO: We can peel off any functions which are invertible *in L*. Loop
9144 // invariant terms are effectively constants for our purposes here.
9145 auto *InnerLHS = LHS;
9146 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9147 InnerLHS = ZExt->getOperand();
9148 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
9149 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
9150 if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9151 StrideC && StrideC->getAPInt().isPowerOf2()) {
9152 auto Flags = AR->getNoWrapFlags();
9153 Flags = setFlags(Flags, SCEV::FlagNW);
9154 SmallVector<const SCEV *> Operands{AR->operands()};
9155 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9156 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9157 }
9158 }
9159 }
9160
9161 switch (Pred) {
9162 case ICmpInst::ICMP_NE: { // while (X != Y)
9163 // Convert to: while (X-Y != 0)
9164 if (LHS->getType()->isPointerTy()) {
9165 LHS = getLosslessPtrToIntExpr(LHS);
9166 if (isa<SCEVCouldNotCompute>(LHS))
9167 return LHS;
9168 }
9169 if (RHS->getType()->isPointerTy()) {
9170 RHS = getLosslessPtrToIntExpr(RHS);
9171 if (isa<SCEVCouldNotCompute>(RHS))
9172 return RHS;
9173 }
9174 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9175 AllowPredicates);
9176 if (EL.hasAnyInfo())
9177 return EL;
9178 break;
9179 }
9180 case ICmpInst::ICMP_EQ: { // while (X == Y)
9181 // Convert to: while (X-Y == 0)
9182 if (LHS->getType()->isPointerTy()) {
9183 LHS = getLosslessPtrToIntExpr(LHS);
9184 if (isa<SCEVCouldNotCompute>(LHS))
9185 return LHS;
9186 }
9187 if (RHS->getType()->isPointerTy()) {
9188 RHS = getLosslessPtrToIntExpr(RHS);
9189 if (isa<SCEVCouldNotCompute>(RHS))
9190 return RHS;
9191 }
9192 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9193 if (EL.hasAnyInfo()) return EL;
9194 break;
9195 }
9196 case ICmpInst::ICMP_SLE:
9197 case ICmpInst::ICMP_ULE:
9198 // Since the loop is finite, an invariant RHS cannot include the boundary
9199 // value, otherwise it would loop forever.
9200 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9201 !isLoopInvariant(RHS, L))
9202 break;
9203 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9204 [[fallthrough]];
9205 case ICmpInst::ICMP_SLT:
9206 case ICmpInst::ICMP_ULT: { // while (X < Y)
9207 bool IsSigned = ICmpInst::isSigned(Pred);
9208 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9209 AllowPredicates);
9210 if (EL.hasAnyInfo())
9211 return EL;
9212 break;
9213 }
9214 case ICmpInst::ICMP_SGE:
9215 case ICmpInst::ICMP_UGE:
9216 // Since the loop is finite, an invariant RHS cannot include the boundary
9217 // value, otherwise it would loop forever.
9218 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9219 !isLoopInvariant(RHS, L))
9220 break;
9221 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9222 [[fallthrough]];
9223 case ICmpInst::ICMP_SGT:
9224 case ICmpInst::ICMP_UGT: { // while (X > Y)
9225 bool IsSigned = ICmpInst::isSigned(Pred);
9226 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9227 AllowPredicates);
9228 if (EL.hasAnyInfo())
9229 return EL;
9230 break;
9231 }
9232 default:
9233 break;
9234 }
9235
9236 return getCouldNotCompute();
9237}
9238
9239ScalarEvolution::ExitLimit
9240ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9241 SwitchInst *Switch,
9242 BasicBlock *ExitingBlock,
9243 bool ControlsOnlyExit) {
9244 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9245
9246 // Give up if the exit is the default dest of a switch.
9247 if (Switch->getDefaultDest() == ExitingBlock)
9248 return getCouldNotCompute();
9249
9250 assert(L->contains(Switch->getDefaultDest()) &&
9251 "Default case must not exit the loop!");
9252 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9253 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9254
9255 // while (X != Y) --> while (X-Y != 0)
9256 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9257 if (EL.hasAnyInfo())
9258 return EL;
9259
9260 return getCouldNotCompute();
9261}
9262
9263static ConstantInt *
9264EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9265 ScalarEvolution &SE) {
9266 const SCEV *InVal = SE.getConstant(C);
9267 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9268 assert(isa<SCEVConstant>(Val) &&
9269 "Evaluation of SCEV at constant didn't fold correctly?");
9270 return cast<SCEVConstant>(Val)->getValue();
9271}
9272
9273ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9274 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9275 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9276 if (!RHS)
9277 return getCouldNotCompute();
9278
9279 const BasicBlock *Latch = L->getLoopLatch();
9280 if (!Latch)
9281 return getCouldNotCompute();
9282
9283 const BasicBlock *Predecessor = L->getLoopPredecessor();
9284 if (!Predecessor)
9285 return getCouldNotCompute();
9286
9287 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9288 // Return LHS in OutLHS and shift_opt in OutOpCode.
9289 auto MatchPositiveShift =
9290 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9291
9292 using namespace PatternMatch;
9293
9294 ConstantInt *ShiftAmt;
9295 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9296 OutOpCode = Instruction::LShr;
9297 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9298 OutOpCode = Instruction::AShr;
9299 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9300 OutOpCode = Instruction::Shl;
9301 else
9302 return false;
9303
9304 return ShiftAmt->getValue().isStrictlyPositive();
9305 };
9306
9307 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9308 //
9309 // loop:
9310 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9311 // %iv.shifted = lshr i32 %iv, <positive constant>
9312 //
9313 // Return true on a successful match. Return the corresponding PHI node (%iv
9314 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9315 auto MatchShiftRecurrence =
9316 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9317 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9318
9319 {
9320 Instruction::BinaryOps OpC;
9321 Value *V;
9322
9323 // If we encounter a shift instruction, "peel off" the shift operation,
9324 // and remember that we did so. Later when we inspect %iv's backedge
9325 // value, we will make sure that the backedge value uses the same
9326 // operation.
9327 //
9328 // Note: the peeled shift operation does not have to be the same
9329 // instruction as the one feeding into the PHI's backedge value. We only
9330 // really care about it being the same *kind* of shift instruction --
9331 // that's all that is required for our later inferences to hold.
9332 if (MatchPositiveShift(LHS, V, OpC)) {
9333 PostShiftOpCode = OpC;
9334 LHS = V;
9335 }
9336 }
9337
9338 PNOut = dyn_cast<PHINode>(LHS);
9339 if (!PNOut || PNOut->getParent() != L->getHeader())
9340 return false;
9341
9342 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9343 Value *OpLHS;
9344
9345 return
9346 // The backedge value for the PHI node must be a shift by a positive
9347 // amount
9348 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9349
9350 // of the PHI node itself
9351 OpLHS == PNOut &&
9352
9353 // and the kind of shift should match the kind of shift we peeled
9354 // off, if any.
9355 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9356 };
9357
9358 PHINode *PN;
9359 Instruction::BinaryOps OpCode;
9360 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9361 return getCouldNotCompute();
9362
9363 const DataLayout &DL = getDataLayout();
9364
9365 // The key rationale for this optimization is that for some kinds of shift
9366 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9367 // within a finite number of iterations. If the condition guarding the
9368 // backedge (in the sense that the backedge is taken if the condition is true)
9369 // is false for the value the shift recurrence stabilizes to, then we know
9370 // that the backedge is taken only a finite number of times.
9371
9372 ConstantInt *StableValue = nullptr;
9373 switch (OpCode) {
9374 default:
9375 llvm_unreachable("Impossible case!");
9376
9377 case Instruction::AShr: {
9378 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9379 // bitwidth(K) iterations.
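// e.g. {-8,ashr,1} produces -8, -4, -2, -1, -1, ... and then stays at -1,
// while a non-negative start value stabilizes to 0 instead.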
9380 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9381 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9382 Predecessor->getTerminator(), &DT);
9383 auto *Ty = cast<IntegerType>(RHS->getType());
9384 if (Known.isNonNegative())
9385 StableValue = ConstantInt::get(Ty, 0);
9386 else if (Known.isNegative())
9387 StableValue = ConstantInt::get(Ty, -1, true);
9388 else
9389 return getCouldNotCompute();
9390
9391 break;
9392 }
9393 case Instruction::LShr:
9394 case Instruction::Shl:
9395 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9396 // stabilize to 0 in at most bitwidth(K) iterations.
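// e.g. {8,lshr,1} produces 8, 4, 2, 1, 0, ... and {1,shl,1} reaches 0 once
// the set bit has been shifted out.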
9397 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9398 break;
9399 }
9400
9401 auto *Result =
9402 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9403 assert(Result->getType()->isIntegerTy(1) &&
9404 "Otherwise cannot be an operand to a branch instruction");
9405
9406 if (Result->isZeroValue()) {
9407 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9408 const SCEV *UpperBound =
9409 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9410 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9411 }
9412
9413 return getCouldNotCompute();
9414}
9415
9416/// Return true if we can constant fold an instruction of the specified type,
9417/// assuming that all operands were constants.
9418static bool CanConstantFold(const Instruction *I) {
9419 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9420 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9421 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9422 return true;
9423
9424 if (const CallInst *CI = dyn_cast<CallInst>(I))
9425 if (const Function *F = CI->getCalledFunction())
9426 return canConstantFoldCallTo(CI, F);
9427 return false;
9428}
9429
9430/// Determine whether this instruction can constant evolve within this loop
9431/// assuming its operands can all constant evolve.
9432static bool canConstantEvolve(Instruction *I, const Loop *L) {
9433 // An instruction outside of the loop can't be derived from a loop PHI.
9434 if (!L->contains(I)) return false;
9435
9436 if (isa<PHINode>(I)) {
9437 // We don't currently keep track of the control flow needed to evaluate
9438 // PHIs, so we cannot handle PHIs inside of loops.
9439 return L->getHeader() == I->getParent();
9440 }
9441
9442 // If we won't be able to constant fold this expression even if the operands
9443 // are constants, bail early.
9444 return CanConstantFold(I);
9445}
9446
9447/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9448/// recursing through each instruction operand until reaching a loop header phi.
9449static PHINode *
9450 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9451 DenseMap<Instruction *, PHINode *> &PHIMap,
9452 unsigned Depth) {
9453 if (Depth > MaxConstantEvolvingDepth)
9454 return nullptr;
9455
9456 // Otherwise, we can evaluate this instruction if all of its operands are
9457 // constant or derived from a PHI node themselves.
9458 PHINode *PHI = nullptr;
9459 for (Value *Op : UseInst->operands()) {
9460 if (isa<Constant>(Op)) continue;
9461
9462 Instruction *OpInst = dyn_cast<Instruction>(Op);
9463 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9464
9465 PHINode *P = dyn_cast<PHINode>(OpInst);
9466 if (!P)
9467 // If this operand is already visited, reuse the prior result.
9468 // We may have P != PHI if this is the deepest point at which the
9469 // inconsistent paths meet.
9470 P = PHIMap.lookup(OpInst);
9471 if (!P) {
9472 // Recurse and memoize the results, whether a phi is found or not.
9473 // This recursive call invalidates pointers into PHIMap.
9474 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9475 PHIMap[OpInst] = P;
9476 }
9477 if (!P)
9478 return nullptr; // Not evolving from PHI
9479 if (PHI && PHI != P)
9480 return nullptr; // Evolving from multiple different PHIs.
9481 PHI = P;
9482 }
9483 // This is an expression evolving from a constant PHI!
9484 return PHI;
9485}
9486
9487/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9488/// in the loop that V is derived from. We allow arbitrary operations along the
9489/// way, but the operands of an operation must either be constants or a value
9490/// derived from a constant PHI. If this expression does not fit with these
9491/// constraints, return null.
9492 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9493 Instruction *I = dyn_cast<Instruction>(V);
9494 if (!I || !canConstantEvolve(I, L)) return nullptr;
9495
9496 if (PHINode *PN = dyn_cast<PHINode>(I))
9497 return PN;
9498
9499 // Record non-constant instructions contained by the loop.
9500 DenseMap<Instruction *, PHINode *> PHIMap;
9501 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9502}
9503
9504/// EvaluateExpression - Given an expression that passes the
9505/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9506/// in the loop has the value PHIVal. If we can't fold this expression for some
9507/// reason, return null.
9508 static Constant *EvaluateExpression(Value *V, const Loop *L,
9509 DenseMap<Instruction *, Constant *> &Vals,
9510 const DataLayout &DL,
9511 const TargetLibraryInfo *TLI) {
9512 // Convenient constant check, but redundant for recursive calls.
9513 if (Constant *C = dyn_cast<Constant>(V)) return C;
9514 Instruction *I = dyn_cast<Instruction>(V);
9515 if (!I) return nullptr;
9516
9517 if (Constant *C = Vals.lookup(I)) return C;
9518
9519 // An instruction inside the loop depends on a value outside the loop that we
9520 // weren't given a mapping for, or a value such as a call inside the loop.
9521 if (!canConstantEvolve(I, L)) return nullptr;
9522
9523 // An unmapped PHI can be due to a branch or another loop inside this loop,
9524 // or due to this not being the initial iteration through a loop where we
9525 // couldn't compute the evolution of this particular PHI last time.
9526 if (isa<PHINode>(I)) return nullptr;
9527
9528 std::vector<Constant*> Operands(I->getNumOperands());
9529
9530 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9531 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9532 if (!Operand) {
9533 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9534 if (!Operands[i]) return nullptr;
9535 continue;
9536 }
9537 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9538 Vals[Operand] = C;
9539 if (!C) return nullptr;
9540 Operands[i] = C;
9541 }
9542
9543 return ConstantFoldInstOperands(I, Operands, DL, TLI);
9544}
9545
9546
9547// If every incoming value to PN except the one for BB is a specific Constant,
9548// return that, else return nullptr.
9549 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9550 Constant *IncomingVal = nullptr;
9551
9552 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9553 if (PN->getIncomingBlock(i) == BB)
9554 continue;
9555
9556 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9557 if (!CurrentVal)
9558 return nullptr;
9559
9560 if (IncomingVal != CurrentVal) {
9561 if (IncomingVal)
9562 return nullptr;
9563 IncomingVal = CurrentVal;
9564 }
9565 }
9566
9567 return IncomingVal;
9568}
9569
9570/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9571/// in the header of its containing loop, we know the loop executes a
9572/// constant number of times, and the PHI node is just a recurrence
9573/// involving constants, fold it.
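///
/// Worked example (editor's addition): for a header PHI following the
/// non-affine recurrence x -> 3*x + 1 with start value 1 and a known
/// backedge-taken count of 3, the loop is executed symbolically through the
/// constant values 1, 4, 13, 40, and the constant 40 is returned.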
9574Constant *
9575ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9576 const APInt &BEs,
9577 const Loop *L) {
9578 auto I = ConstantEvolutionLoopExitValue.find(PN);
9579 if (I != ConstantEvolutionLoopExitValue.end())
9580 return I->second;
9581
9582 if (BEs.ugt(MaxBruteForceIterations))
9583 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
9584
9585 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
9586
9587 DenseMap<Instruction *, Constant *> CurrentIterVals;
9588 BasicBlock *Header = L->getHeader();
9589 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9590
9591 BasicBlock *Latch = L->getLoopLatch();
9592 if (!Latch)
9593 return nullptr;
9594
9595 for (PHINode &PHI : Header->phis()) {
9596 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9597 CurrentIterVals[&PHI] = StartCST;
9598 }
9599 if (!CurrentIterVals.count(PN))
9600 return RetVal = nullptr;
9601
9602 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9603
9604 // Execute the loop symbolically to determine the exit value.
9605 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9606 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9607
9608 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9609 unsigned IterationNum = 0;
9610 const DataLayout &DL = getDataLayout();
9611 for (; ; ++IterationNum) {
9612 if (IterationNum == NumIterations)
9613 return RetVal = CurrentIterVals[PN]; // Got exit value!
9614
9615 // Compute the value of the PHIs for the next iteration.
9616 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9617 DenseMap<Instruction *, Constant *> NextIterVals;
9618 Constant *NextPHI =
9619 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9620 if (!NextPHI)
9621 return nullptr; // Couldn't evaluate!
9622 NextIterVals[PN] = NextPHI;
9623
9624 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9625
9626 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9627 // cease to be able to evaluate one of them or if they stop evolving,
9628 // because that doesn't necessarily prevent us from computing PN.
9629 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9630 for (const auto &I : CurrentIterVals) {
9631 PHINode *PHI = dyn_cast<PHINode>(I.first);
9632 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9633 PHIsToCompute.emplace_back(PHI, I.second);
9634 }
9635 // We use two distinct loops because EvaluateExpression may invalidate any
9636 // iterators into CurrentIterVals.
9637 for (const auto &I : PHIsToCompute) {
9638 PHINode *PHI = I.first;
9639 Constant *&NextPHI = NextIterVals[PHI];
9640 if (!NextPHI) { // Not already computed.
9641 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9642 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9643 }
9644 if (NextPHI != I.second)
9645 StoppedEvolving = false;
9646 }
9647
9648 // If all entries in CurrentIterVals == NextIterVals then we can stop
9649 // iterating, the loop can't continue to change.
9650 if (StoppedEvolving)
9651 return RetVal = CurrentIterVals[PN];
9652
9653 CurrentIterVals.swap(NextIterVals);
9654 }
9655}
9656
9657const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9658 Value *Cond,
9659 bool ExitWhen) {
9660 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9661 if (!PN) return getCouldNotCompute();
9662
9663 // If the loop is canonicalized, the PHI will have exactly two entries.
9664 // That's the only form we support here.
9665 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9666
9667 DenseMap<Instruction *, Constant *> CurrentIterVals;
9668 BasicBlock *Header = L->getHeader();
9669 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9670
9671 BasicBlock *Latch = L->getLoopLatch();
9672 assert(Latch && "Should follow from NumIncomingValues == 2!");
9673
9674 for (PHINode &PHI : Header->phis()) {
9675 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9676 CurrentIterVals[&PHI] = StartCST;
9677 }
9678 if (!CurrentIterVals.count(PN))
9679 return getCouldNotCompute();
9680
9681 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9682 // the loop symbolically to determine when the condition gets a value of
9683 // "ExitWhen".
9684 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9685 const DataLayout &DL = getDataLayout();
9686 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9687 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9688 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9689
9690 // Couldn't symbolically evaluate.
9691 if (!CondVal) return getCouldNotCompute();
9692
9693 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9694 ++NumBruteForceTripCountsComputed;
9695 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9696 }
9697
9698 // Update all the PHI nodes for the next iteration.
9699 DenseMap<Instruction *, Constant *> NextIterVals;
9700
9701 // Create a list of which PHIs we need to compute. We want to do this before
9702 // calling EvaluateExpression on them because that may invalidate iterators
9703 // into CurrentIterVals.
9704 SmallVector<PHINode *, 8> PHIsToCompute;
9705 for (const auto &I : CurrentIterVals) {
9706 PHINode *PHI = dyn_cast<PHINode>(I.first);
9707 if (!PHI || PHI->getParent() != Header) continue;
9708 PHIsToCompute.push_back(PHI);
9709 }
9710 for (PHINode *PHI : PHIsToCompute) {
9711 Constant *&NextPHI = NextIterVals[PHI];
9712 if (NextPHI) continue; // Already computed!
9713
9714 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9715 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9716 }
9717 CurrentIterVals.swap(NextIterVals);
9718 }
9719
9720 // Too many iterations were needed to evaluate.
9721 return getCouldNotCompute();
9722}
9723
9724const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9725 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9726 ValuesAtScopes[V];
9727 // Check to see if we've folded this expression at this loop before.
9728 for (auto &LS : Values)
9729 if (LS.first == L)
9730 return LS.second ? LS.second : V;
9731
9732 Values.emplace_back(L, nullptr);
9733
9734 // Otherwise compute it.
9735 const SCEV *C = computeSCEVAtScope(V, L);
9736 for (auto &LS : reverse(ValuesAtScopes[V]))
9737 if (LS.first == L) {
9738 LS.second = C;
9739 if (!isa<SCEVConstant>(C))
9740 ValuesAtScopesUsers[C].push_back({L, V});
9741 break;
9742 }
9743 return C;
9744}
9745
9746/// This builds up a Constant using the ConstantExpr interface. That way, we
9747/// will return Constants for objects which aren't represented by a
9748/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9749/// Returns NULL if the SCEV isn't representable as a Constant.
9750 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9751 switch (V->getSCEVType()) {
9752 case scCouldNotCompute:
9753 case scAddRecExpr:
9754 case scVScale:
9755 return nullptr;
9756 case scConstant:
9757 return cast<SCEVConstant>(V)->getValue();
9758 case scUnknown:
9759 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9760 case scPtrToInt: {
9761 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9762 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9763 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9764
9765 return nullptr;
9766 }
9767 case scTruncate: {
9768 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9769 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9770 return ConstantExpr::getTrunc(CastOp, ST->getType());
9771 return nullptr;
9772 }
9773 case scAddExpr: {
9774 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9775 Constant *C = nullptr;
9776 for (const SCEV *Op : SA->operands()) {
9777 Constant *OpC = BuildConstantFromSCEV(Op);
9778 if (!OpC)
9779 return nullptr;
9780 if (!C) {
9781 C = OpC;
9782 continue;
9783 }
9784 assert(!C->getType()->isPointerTy() &&
9785 "Can only have one pointer, and it must be last");
9786 if (OpC->getType()->isPointerTy()) {
9787 // The offsets have been converted to bytes. We can add bytes using
9788 // an i8 GEP.
9789 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9790 OpC, C);
9791 } else {
9792 C = ConstantExpr::getAdd(C, OpC);
9793 }
9794 }
9795 return C;
9796 }
9797 case scMulExpr:
9798 case scSignExtend:
9799 case scZeroExtend:
9800 case scUDivExpr:
9801 case scSMaxExpr:
9802 case scUMaxExpr:
9803 case scSMinExpr:
9804 case scUMinExpr:
9805 case scSequentialUMinExpr:
9806 return nullptr;
9807 }
9808 llvm_unreachable("Unknown SCEV kind!");
9809}
9810
9811const SCEV *
9812 ScalarEvolution::getWithOperands(const SCEV *S,
9813 SmallVectorImpl<const SCEV *> &NewOps) {
9814 switch (S->getSCEVType()) {
9815 case scTruncate:
9816 case scZeroExtend:
9817 case scSignExtend:
9818 case scPtrToInt:
9819 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9820 case scAddRecExpr: {
9821 auto *AddRec = cast<SCEVAddRecExpr>(S);
9822 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9823 }
9824 case scAddExpr:
9825 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9826 case scMulExpr:
9827 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9828 case scUDivExpr:
9829 return getUDivExpr(NewOps[0], NewOps[1]);
9830 case scUMaxExpr:
9831 case scSMaxExpr:
9832 case scUMinExpr:
9833 case scSMinExpr:
9834 return getMinMaxExpr(S->getSCEVType(), NewOps);
9835 case scSequentialUMinExpr:
9836 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9837 case scConstant:
9838 case scVScale:
9839 case scUnknown:
9840 return S;
9841 case scCouldNotCompute:
9842 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9843 }
9844 llvm_unreachable("Unknown SCEV kind!");
9845}
9846
9847const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9848 switch (V->getSCEVType()) {
9849 case scConstant:
9850 case scVScale:
9851 return V;
9852 case scAddRecExpr: {
9853 // If this is a loop recurrence for a loop that does not contain L, then we
9854 // are dealing with the final value computed by the loop.
9855 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
9856 // First, attempt to evaluate each operand.
9857 // Avoid performing the look-up in the common case where the specified
9858 // expression has no loop-variant portions.
9859 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9860 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9861 if (OpAtScope == AddRec->getOperand(i))
9862 continue;
9863
9864 // Okay, at least one of these operands is loop variant but might be
9865 // foldable. Build a new instance of the folded commutative expression.
9866 SmallVector<const SCEV *, 8> NewOps;
9867 NewOps.reserve(AddRec->getNumOperands());
9868 append_range(NewOps, AddRec->operands().take_front(i));
9869 NewOps.push_back(OpAtScope);
9870 for (++i; i != e; ++i)
9871 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
9872
9873 const SCEV *FoldedRec = getAddRecExpr(
9874 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
9875 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
9876 // The addrec may be folded to a nonrecurrence, for example, if the
9877 // induction variable is multiplied by zero after constant folding. Go
9878 // ahead and return the folded value.
9879 if (!AddRec)
9880 return FoldedRec;
9881 break;
9882 }
9883
9884 // If the scope is outside the addrec's loop, evaluate it by using the
9885 // loop exit value of the addrec.
9886 if (!AddRec->getLoop()->contains(L)) {
9887 // To evaluate this recurrence, we need to know how many times the AddRec
9888 // loop iterates. Compute this now.
9889 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
9890 if (BackedgeTakenCount == getCouldNotCompute())
9891 return AddRec;
9892
9893 // Then, evaluate the AddRec.
9894 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
9895 }
9896
9897 return AddRec;
9898 }
9899 case scTruncate:
9900 case scZeroExtend:
9901 case scSignExtend:
9902 case scPtrToInt:
9903 case scAddExpr:
9904 case scMulExpr:
9905 case scUDivExpr:
9906 case scUMaxExpr:
9907 case scSMaxExpr:
9908 case scUMinExpr:
9909 case scSMinExpr:
9910 case scSequentialUMinExpr: {
9911 ArrayRef<const SCEV *> Ops = V->operands();
9912 // Avoid performing the look-up in the common case where the specified
9913 // expression has no loop-variant portions.
9914 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
9915 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
9916 if (OpAtScope != Ops[i]) {
9917 // Okay, at least one of these operands is loop variant but might be
9918 // foldable. Build a new instance of the folded commutative expression.
9919 SmallVector<const SCEV *, 8> NewOps;
9920 NewOps.reserve(Ops.size());
9921 append_range(NewOps, Ops.take_front(i));
9922 NewOps.push_back(OpAtScope);
9923
9924 for (++i; i != e; ++i) {
9925 OpAtScope = getSCEVAtScope(Ops[i], L);
9926 NewOps.push_back(OpAtScope);
9927 }
9928
9929 return getWithOperands(V, NewOps);
9930 }
9931 }
9932 // If we got here, all operands are loop invariant.
9933 return V;
9934 }
9935 case scUnknown: {
9936 // If this instruction is evolved from a constant-evolving PHI, compute the
9937 // exit value from the loop without using SCEVs.
9938 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
9939 Instruction *I = dyn_cast<Instruction>(SU->getValue());
9940 if (!I)
9941 return V; // This is some other type of SCEVUnknown, just return it.
9942
9943 if (PHINode *PN = dyn_cast<PHINode>(I)) {
9944 const Loop *CurrLoop = this->LI[I->getParent()];
9945 // Looking for loop exit value.
9946 if (CurrLoop && CurrLoop->getParentLoop() == L &&
9947 PN->getParent() == CurrLoop->getHeader()) {
9948 // Okay, there is no closed form solution for the PHI node. Check
9949 // to see if the loop that contains it has a known backedge-taken
9950 // count. If so, we may be able to force computation of the exit
9951 // value.
9952 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
9953 // This trivial case can show up in some degenerate cases where
9954 // the incoming IR has not yet been fully simplified.
9955 if (BackedgeTakenCount->isZero()) {
9956 Value *InitValue = nullptr;
9957 bool MultipleInitValues = false;
9958 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
9959 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
9960 if (!InitValue)
9961 InitValue = PN->getIncomingValue(i);
9962 else if (InitValue != PN->getIncomingValue(i)) {
9963 MultipleInitValues = true;
9964 break;
9965 }
9966 }
9967 }
9968 if (!MultipleInitValues && InitValue)
9969 return getSCEV(InitValue);
9970 }
9971 // Do we have a loop invariant value flowing around the backedge
9972 // for a loop which must execute the backedge?
9973 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
9974 isKnownNonZero(BackedgeTakenCount) &&
9975 PN->getNumIncomingValues() == 2) {
9976
9977 unsigned InLoopPred =
9978 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
9979 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
9980 if (CurrLoop->isLoopInvariant(BackedgeVal))
9981 return getSCEV(BackedgeVal);
9982 }
9983 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
9984 // Okay, we know how many times the containing loop executes. If
9985 // this is a constant evolving PHI node, get the final value at
9986 // the specified iteration number.
9987 Constant *RV =
9988 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
9989 if (RV)
9990 return getSCEV(RV);
9991 }
9992 }
9993 }
9994
9995 // Okay, this is an expression that we cannot symbolically evaluate
9996 // into a SCEV. Check to see if it's possible to symbolically evaluate
9997 // the arguments into constants, and if so, try to constant propagate the
9998 // result. This is particularly useful for computing loop exit values.
9999 if (!CanConstantFold(I))
10000 return V; // This is some other type of SCEVUnknown, just return it.
10001
10002 SmallVector<Constant *, 4> Operands;
10003 Operands.reserve(I->getNumOperands());
10004 bool MadeImprovement = false;
10005 for (Value *Op : I->operands()) {
10006 if (Constant *C = dyn_cast<Constant>(Op)) {
10007 Operands.push_back(C);
10008 continue;
10009 }
10010
10011 // If any of the operands is non-constant and if they are
10012 // non-integer and non-pointer, don't even try to analyze them
10013 // with scev techniques.
10014 if (!isSCEVable(Op->getType()))
10015 return V;
10016
10017 const SCEV *OrigV = getSCEV(Op);
10018 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10019 MadeImprovement |= OrigV != OpV;
10020
10021 Constant *C = BuildConstantFromSCEV(OpV);
10022 if (!C)
10023 return V;
10024 assert(C->getType() == Op->getType() && "Type mismatch");
10025 Operands.push_back(C);
10026 }
10027
10028 // Check to see if getSCEVAtScope actually made an improvement.
10029 if (!MadeImprovement)
10030 return V; // This is some other type of SCEVUnknown, just return it.
10031
10032 Constant *C = nullptr;
10033 const DataLayout &DL = getDataLayout();
10034 C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
10035 if (!C)
10036 return V;
10037 return getSCEV(C);
10038 }
10039 case scCouldNotCompute:
10040 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10041 }
10042 llvm_unreachable("Unknown SCEV type!");
10043}
10044
10045 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10046 return getSCEVAtScope(getSCEV(V), L);
10047}
10048
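// Editor's note (added comment): sign- and zero-extension are injective and
// map zero to zero, so stripping them does not change the first iteration at
// which the wrapped expression becomes zero; howFarToZero below relies on
// this when looking through casts to find an addrec.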
10049const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10050 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10051 return stripInjectiveFunctions(ZExt->getOperand());
10052 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10053 return stripInjectiveFunctions(SExt->getOperand());
10054 return S;
10055}
10056
10057/// Finds the minimum unsigned root of the following equation:
10058///
10059/// A * X = B (mod N)
10060///
10061/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10062/// A and B isn't important.
10063///
10064/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
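///
/// Worked example (editor's addition): with BW = 8, A = 6 and B = 10 we have
/// D = gcd(6, 256) = 2, which divides 10; A/D = 3 has multiplicative inverse
/// 43 modulo 128, so the minimum root is (43 * 10 mod 256) / 2 = 87, and
/// indeed 6 * 87 = 522 == 10 (mod 256).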
10065static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10066 ScalarEvolution &SE) {
10067 uint32_t BW = A.getBitWidth();
10068 assert(BW == SE.getTypeSizeInBits(B->getType()));
10069 assert(A != 0 && "A must be non-zero.");
10070
10071 // 1. D = gcd(A, N)
10072 //
10073 // The gcd of A and N may have only one prime factor: 2. The number of
10074 // trailing zeros in A is its multiplicity
10075 uint32_t Mult2 = A.countr_zero();
10076 // D = 2^Mult2
10077
10078 // 2. Check if B is divisible by D.
10079 //
10080 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10081 // is not less than multiplicity of this prime factor for D.
10082 if (SE.getMinTrailingZeros(B) < Mult2)
10083 return SE.getCouldNotCompute();
10084
10085 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10086 // modulo (N / D).
10087 //
10088 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10089 // (N / D) in general. The inverse itself always fits into BW bits, though,
10090 // so we immediately truncate it.
10091 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10092 APInt I = AD.multiplicativeInverse().zext(BW);
10093
10094 // 4. Compute the minimum unsigned root of the equation:
10095 // I * (B / D) mod (N / D)
10096 // To simplify the computation, we factor out the divide by D:
10097 // (I * B mod N) / D
10098 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10099 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10100}
10101
10102/// For a given quadratic addrec, generate coefficients of the corresponding
10103/// quadratic equation, multiplied by a common value to ensure that they are
10104/// integers.
10105/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10106/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10107/// were multiplied by, and BitWidth is the bit width of the original addrec
10108/// coefficients.
10109/// This function returns std::nullopt if the addrec coefficients are not
10110 /// compile-time constants.
10111static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10112 GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10113 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10114 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10115 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10116 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10117 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10118 << *AddRec << '\n');
10119
10120 // We currently can only solve this if the coefficients are constants.
10121 if (!LC || !MC || !NC) {
10122 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10123 return std::nullopt;
10124 }
10125
10126 APInt L = LC->getAPInt();
10127 APInt M = MC->getAPInt();
10128 APInt N = NC->getAPInt();
10129 assert(!N.isZero() && "This is not a quadratic addrec");
10130
10131 unsigned BitWidth = LC->getAPInt().getBitWidth();
10132 unsigned NewWidth = BitWidth + 1;
10133 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10134 << BitWidth << '\n');
10135 // The sign-extension (as opposed to a zero-extension) here matches the
10136 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10137 N = N.sext(NewWidth);
10138 M = M.sext(NewWidth);
10139 L = L.sext(NewWidth);
10140
10141 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10142 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10143 // L+M, L+2M+N, L+3M+3N, ...
10144 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10145 //
10146 // The equation Acc = 0 is then
10147 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10148 // In a quadratic form it becomes:
10149 // N n^2 + (2M-N) n + 2L = 0.
10150
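// Worked example (editor's addition): for the quadratic chrec {1,+,2,+,4} we
// have L = 1, M = 2, N = 4, so Acc = 1 + 2n + 2n(n-1) and the generated
// equation is 4n^2 + 0n + 2 = 0, i.e. A = 4, B = 2M - N = 0, C = 2L = 2,
// with the common multiplier 2 returned as the fourth tuple element.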
10151 APInt A = N;
10152 APInt B = 2 * M - A;
10153 APInt C = 2 * L;
10154 APInt T = APInt(NewWidth, 2);
10155 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10156 << "x + " << C << ", coeff bw: " << NewWidth
10157 << ", multiplied by " << T << '\n');
10158 return std::make_tuple(A, B, C, T, BitWidth);
10159}
10160
10161/// Helper function to compare optional APInts:
10162/// (a) if X and Y both exist, return min(X, Y),
10163/// (b) if neither X nor Y exist, return std::nullopt,
10164/// (c) if exactly one of X and Y exists, return that value.
10165static std::optional<APInt> MinOptional(std::optional<APInt> X,
10166 std::optional<APInt> Y) {
10167 if (X && Y) {
10168 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10169 APInt XW = X->sext(W);
10170 APInt YW = Y->sext(W);
10171 return XW.slt(YW) ? *X : *Y;
10172 }
10173 if (!X && !Y)
10174 return std::nullopt;
10175 return X ? *X : *Y;
10176}
10177
10178/// Helper function to truncate an optional APInt to a given BitWidth.
10179/// When solving addrec-related equations, it is preferable to return a value
10180/// that has the same bit width as the original addrec's coefficients. If the
10181/// solution fits in the original bit width, truncate it (except for i1).
10182/// Returning a value of a different bit width may inhibit some optimizations.
10183///
10184/// In general, a solution to a quadratic equation generated from an addrec
10185/// may require BW+1 bits, where BW is the bit width of the addrec's
10186/// coefficients. The reason is that the coefficients of the quadratic
10187/// equation are BW+1 bits wide (to avoid truncation when converting from
10188/// the addrec to the equation).
10189static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10190 unsigned BitWidth) {
10191 if (!X)
10192 return std::nullopt;
10193 unsigned W = X->getBitWidth();
10194 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10195 return X->trunc(BitWidth);
10196 return X;
10197}
10198
10199/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10200/// iterations. The values L, M, N are assumed to be signed, and they
10201/// should all have the same bit widths.
10202/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10203/// where BW is the bit width of the addrec's coefficients.
10204/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10205/// returned as such, otherwise the bit width of the returned value may
10206/// be greater than BW.
10207///
10208/// This function returns std::nullopt if
10209/// (a) the addrec coefficients are not constant, or
10210/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10211/// like x^2 = 5, no integer solutions exist, in other cases an integer
10212/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
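///
/// For instance (editor's addition): the chrec {-6,+,2,+,2} takes the values
/// -6, -4, 0 at n = 0, 1, 2, so the least n with c(n) == 0 is 2.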
10213static std::optional<APInt>
10214 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10215 APInt A, B, C, M;
10216 unsigned BitWidth;
10217 auto T = GetQuadraticEquation(AddRec);
10218 if (!T)
10219 return std::nullopt;
10220
10221 std::tie(A, B, C, M, BitWidth) = *T;
10222 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10223 std::optional<APInt> X =
10224 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10225 if (!X)
10226 return std::nullopt;
10227
10228 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10229 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10230 if (!V->isZero())
10231 return std::nullopt;
10232
10233 return TruncIfPossible(X, BitWidth);
10234}
10235
10236/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10237/// iterations. The values M, N are assumed to be signed, and they
10238/// should all have the same bit widths.
10239/// Find the least n such that c(n) does not belong to the given range,
10240/// while c(n-1) does.
10241///
10242/// This function returns std::nullopt if
10243/// (a) the addrec coefficients are not constant, or
10244/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10245/// bounds of the range.
10246static std::optional<APInt>
10247 SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10248 const ConstantRange &Range, ScalarEvolution &SE) {
10249 assert(AddRec->getOperand(0)->isZero() &&
10250 "Starting value of addrec should be 0");
10251 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10252 << Range << ", addrec " << *AddRec << '\n');
10253 // This case is handled in getNumIterationsInRange. Here we can assume that
10254 // we start in the range.
10255 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10256 "Addrec's initial value should be in range");
10257
10258 APInt A, B, C, M;
10259 unsigned BitWidth;
10260 auto T = GetQuadraticEquation(AddRec);
10261 if (!T)
10262 return std::nullopt;
10263
10264 // Be careful about the return value: there can be two reasons for not
10265 // returning an actual number. First, if no solutions to the equations
10266 // were found, and second, if the solutions don't leave the given range.
10267 // The first case means that the actual solution is "unknown", the second
10268 // means that it's known, but not valid. If the solution is unknown, we
10269 // cannot make any conclusions.
10270 // Return a pair: the optional solution and a flag indicating if the
10271 // solution was found.
10272 auto SolveForBoundary =
10273 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10274 // Solve for signed overflow and unsigned overflow, pick the lower
10275 // solution.
10276 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10277 << Bound << " (before multiplying by " << M << ")\n");
10278 Bound *= M; // The quadratic equation multiplier.
10279
10280 std::optional<APInt> SO;
10281 if (BitWidth > 1) {
10282 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10283 "signed overflow\n");
10284 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10285 }
10286 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10287 "unsigned overflow\n");
10288 std::optional<APInt> UO =
10289 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10290
10291 auto LeavesRange = [&] (const APInt &X) {
10292 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10293 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10294 if (Range.contains(V0->getValue()))
10295 return false;
10296 // X should be at least 1, so X-1 is non-negative.
10297 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10298 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10299 if (Range.contains(V1->getValue()))
10300 return true;
10301 return false;
10302 };
10303
10304 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10305 // can be a solution, but the function failed to find it. We cannot treat it
10306 // as "no solution".
10307 if (!SO || !UO)
10308 return {std::nullopt, false};
10309
10310 // Check the smaller value first to see if it leaves the range.
10311 // At this point, both SO and UO must have values.
10312 std::optional<APInt> Min = MinOptional(SO, UO);
10313 if (LeavesRange(*Min))
10314 return { Min, true };
10315 std::optional<APInt> Max = Min == SO ? UO : SO;
10316 if (LeavesRange(*Max))
10317 return { Max, true };
10318
10319 // Solutions were found, but were eliminated, hence the "true".
10320 return {std::nullopt, true};
10321 };
10322
10323 std::tie(A, B, C, M, BitWidth) = *T;
10324 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10325 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10326 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10327 auto SL = SolveForBoundary(Lower);
10328 auto SU = SolveForBoundary(Upper);
10329 // If any of the solutions was unknown, no meaningful conclusions can
10330 // be made.
10331 if (!SL.second || !SU.second)
10332 return std::nullopt;
10333
10334 // Claim: The correct solution is not some value between Min and Max.
10335 //
10336 // Justification: Assuming that Min and Max are different values, one of
10337 // them is when the first signed overflow happens, the other is when the
10338 // first unsigned overflow happens. Crossing the range boundary is only
10339 // possible via an overflow (treating 0 as a special case of it, modeling
10340 // an overflow as crossing k*2^W for some k).
10341 //
10342 // The interesting case here is when Min was eliminated as an invalid
10343 // solution, but Max was not. The argument is that if there was another
10344 // overflow between Min and Max, it would also have been eliminated if
10345 // it was considered.
10346 //
10347 // For a given boundary, it is possible to have two overflows of the same
10348 // type (signed/unsigned) without having the other type in between: this
10349 // can happen when the vertex of the parabola is between the iterations
10350 // corresponding to the overflows. This is only possible when the two
10351 // overflows cross k*2^W for the same k. In such case, if the second one
10352 // left the range (and was the first one to do so), the first overflow
10353 // would have to enter the range, which would mean that either we had left
10354 // the range before or that we started outside of it. Both of these cases
10355 // are contradictions.
10356 //
10357 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10358 // solution is not some value between the Max for this boundary and the
10359 // Min of the other boundary.
10360 //
10361 // Justification: Assume that we had such Max_A and Min_B corresponding
10362 // to range boundaries A and B and such that Max_A < Min_B. If there was
10363 // a solution between Max_A and Min_B, it would have to be caused by an
10364 // overflow corresponding to either A or B. It cannot correspond to B,
10365 // since Min_B is the first occurrence of such an overflow. If it
10366 // corresponded to A, it would have to be either a signed or an unsigned
10367 // overflow that is larger than both eliminated overflows for A. But
10368 // between the eliminated overflows and this overflow, the values would
10369 // cover the entire value space, thus crossing the other boundary, which
10370 // is a contradiction.
10371
10372 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10373}
10374
10375ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10376 const Loop *L,
10377 bool ControlsOnlyExit,
10378 bool AllowPredicates) {
10379
10380 // This is only used for loops with a "x != y" exit test. The exit condition
10381 // is now expressed as a single expression, V = x-y. So the exit test is
10382 // effectively V != 0. We know, and take advantage of, the fact that this
10383 // expression is only used in a comparison-with-zero context.
10384
10385 SmallPtrSet<const SCEVPredicate *, 4> Predicates;
10386 // If the value is a constant
10387 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10388 // If the value is already zero, the branch will execute zero times.
10389 if (C->getValue()->isZero()) return C;
10390 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10391 }
10392
10393 const SCEVAddRecExpr *AddRec =
10394 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10395
10396 if (!AddRec && AllowPredicates)
10397 // Try to make this an AddRec using runtime tests, in the first X
10398 // iterations of this loop, where X is the SCEV expression found by the
10399 // algorithm below.
10400 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10401
10402 if (!AddRec || AddRec->getLoop() != L)
10403 return getCouldNotCompute();
10404
10405 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10406 // the quadratic equation to solve it.
10407 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10408 // We can only use this value if the chrec ends up with an exact zero
10409 // value at this index. When solving for "X*X != 5", for example, we
10410 // should not accept a root of 2.
10411 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10412 const auto *R = cast<SCEVConstant>(getConstant(*S));
10413 return ExitLimit(R, R, R, false, Predicates);
10414 }
10415 return getCouldNotCompute();
10416 }
10417
10418 // Otherwise we can only handle this if it is affine.
10419 if (!AddRec->isAffine())
10420 return getCouldNotCompute();
10421
10422 // If this is an affine expression, the execution count of this branch is
10423 // the minimum unsigned root of the following equation:
10424 //
10425 // Start + Step*N = 0 (mod 2^BW)
10426 //
10427 // equivalent to:
10428 //
10429 // Step*N = -Start (mod 2^BW)
10430 //
10431 // where BW is the common bit width of Start and Step.
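// Worked example (editor's addition): for the affine addrec {10,+,-2} tested
// against zero, Start = 10 and Step = -2, so N solves -2*N = -10 (mod 2^BW)
// and the backedge-taken count is N = 5.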
10432
10433 // Get the initial value for the loop.
10434 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10435 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10436
10437 // For now we handle only constant steps.
10438 //
10439 // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
10440 // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
10441 // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
10442 // We have not yet seen any such cases.
10443 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10444 if (!StepC || StepC->getValue()->isZero())
10445 return getCouldNotCompute();
10446
10447 // For positive steps (counting up until unsigned overflow):
10448 // N = -Start/Step (as unsigned)
10449 // For negative steps (counting down to zero):
10450 // N = Start/-Step
10451 // First compute the unsigned distance from zero in the direction of Step.
10452 bool CountDown = StepC->getAPInt().isNegative();
10453 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10454
10455 // Handle unitary steps, which cannot wraparound.
10456 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10457 // N = Distance (as unsigned)
10458 if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
10459 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
10460 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10461
10462 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10463 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10464 // case, and see if we can improve the bound.
10465 //
10466 // Explicitly handling this here is necessary because getUnsignedRange
10467 // isn't context-sensitive; it doesn't know that we only care about the
10468 // range inside the loop.
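// Illustrative example (editor's addition): after rotation of
// "for (i = 0; i != n; ++i)", Distance is n - 1, whose unsigned maximum is
// UINT_MAX because n may be 0. The entry guard n != 0 (i.e. Distance + 1 != 0)
// lets us tighten the bound to unsigned_max(n) - 1 below.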
10469 const SCEV *Zero = getZero(Distance->getType());
10470 const SCEV *One = getOne(Distance->getType());
10471 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10472 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10473 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10474 // as "unsigned_max(Distance + 1) - 1".
10475 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10476 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10477 }
10478 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10479 Predicates);
10480 }
10481
10482 // If the condition controls loop exit (the loop exits only if the expression
10483 // is true) and the addition is no-wrap we can use unsigned divide to
10484 // compute the backedge count. In this case, the step may not divide the
10485 // distance, but we don't care because if the condition is "missed" the loop
10486 // will have undefined behavior due to wrapping.
10487 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10488 loopHasNoAbnormalExits(AddRec->getLoop())) {
10489 const SCEV *Exact =
10490 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10491 const SCEV *ConstantMax = getCouldNotCompute();
10492 if (Exact != getCouldNotCompute()) {
10493 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
10494 ConstantMax =
10495 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10496 }
10497 const SCEV *SymbolicMax =
10498 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10499 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10500 }
10501
10502 // Solve the general equation.
10503 const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
10504 getNegativeSCEV(Start), *this);
10505
10506 const SCEV *M = E;
10507 if (E != getCouldNotCompute()) {
10508 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
10509 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10510 }
10511 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10512 return ExitLimit(E, M, S, false, Predicates);
10513}
10514
10515 ScalarEvolution::ExitLimit
10516 ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10517 // Loops that look like: while (X == 0) are very strange indeed. We don't
10518 // handle them yet except for the trivial case. This could be expanded in the
10519 // future as needed.
10520
10521 // If the value is a constant, check to see if it is known to be non-zero
10522 // already. If so, the backedge will execute zero times.
10523 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10524 if (!C->getValue()->isZero())
10525 return getZero(C->getType());
10526 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10527 }
10528
10529 // We could implement others, but I really doubt anyone writes loops like
10530 // this, and if they did, they would already be constant folded.
10531 return getCouldNotCompute();
10532}
10533
10534std::pair<const BasicBlock *, const BasicBlock *>
10535ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10536 const {
10537 // If the block has a unique predecessor, then there is no path from the
10538 // predecessor to the block that does not go through the direct edge
10539 // from the predecessor to the block.
10540 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10541 return {Pred, BB};
10542
10543 // A loop's header is defined to be a block that dominates the loop.
10544 // If the header has a unique predecessor outside the loop, it must be
10545 // a block that has exactly one successor that can reach the loop.
10546 if (const Loop *L = LI.getLoopFor(BB))
10547 return {L->getLoopPredecessor(), L->getHeader()};
10548
10549 return {nullptr, nullptr};
10550}
10551
10552/// SCEV structural equivalence is usually sufficient for testing whether two
10553/// expressions are equal, however for the purposes of looking for a condition
10554/// guarding a loop, it can be useful to be a little more general, since a
10555/// front-end may have replicated the controlling expression.
10556static bool HasSameValue(const SCEV *A, const SCEV *B) {
10557 // Quick check to see if they are the same SCEV.
10558 if (A == B) return true;
10559
10560 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10561 // Not all instructions that are "identical" compute the same value. For
10562 // instance, two distinct alloca instructions allocating the same type are
10563 // identical and do not read memory, but they compute distinct values.
10564 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10565 };
10566
10567 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10568 // two different instructions with the same value. Check for this case.
10569 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10570 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10571 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10572 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10573 if (ComputesEqualValues(AI, BI))
10574 return true;
10575
10576 // Otherwise assume they may have a different value.
10577 return false;
10578}
10579
10580static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10581 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10582 if (!Add || Add->getNumOperands() != 2)
10583 return false;
10584 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10585 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10586 LHS = Add->getOperand(1);
10587 RHS = ME->getOperand(1);
10588 return true;
10589 }
10590 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10591 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10592 LHS = Add->getOperand(0);
10593 RHS = ME->getOperand(1);
10594 return true;
10595 }
10596 return false;
10597}
10598
10599 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
10600 const SCEV *&LHS, const SCEV *&RHS,
10601 unsigned Depth) {
10602 bool Changed = false;
10603 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10604 // '0 != 0'.
10605 auto TrivialCase = [&](bool TriviallyTrue) {
10606 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10607 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10608 return true;
10609 };
10610 // If we hit the max recursion limit bail out.
10611 if (Depth >= 3)
10612 return false;
10613
10614 // Canonicalize a constant to the right side.
10615 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10616 // Check for both operands constant.
10617 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10618 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10619 return TrivialCase(false);
10620 return TrivialCase(true);
10621 }
10622 // Otherwise swap the operands to put the constant on the right.
10623 std::swap(LHS, RHS);
10624 Pred = ICmpInst::getSwappedPredicate(Pred);
10625 Changed = true;
10626 }
10627
10628 // If we're comparing an addrec with a value which is loop-invariant in the
10629 // addrec's loop, put the addrec on the left. Also make a dominance check,
10630 // as both operands could be addrecs loop-invariant in each other's loop.
10631 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10632 const Loop *L = AR->getLoop();
10633 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10634 std::swap(LHS, RHS);
10635 Pred = ICmpInst::getSwappedPredicate(Pred);
10636 Changed = true;
10637 }
10638 }
10639
10640 // If there's a constant operand, canonicalize comparisons with boundary
10641 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10642 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10643 const APInt &RA = RC->getAPInt();
10644
10645 bool SimplifiedByConstantRange = false;
10646
10647 if (!ICmpInst::isEquality(Pred)) {
10648 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10649 if (ExactCR.isFullSet())
10650 return TrivialCase(true);
10651 if (ExactCR.isEmptySet())
10652 return TrivialCase(false);
10653
10654 APInt NewRHS;
10655 CmpInst::Predicate NewPred;
10656 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10657 ICmpInst::isEquality(NewPred)) {
10658 // We were able to convert an inequality to an equality.
10659 Pred = NewPred;
10660 RHS = getConstant(NewRHS);
10661 Changed = SimplifiedByConstantRange = true;
10662 }
10663 }
10664
10665 if (!SimplifiedByConstantRange) {
10666 switch (Pred) {
10667 default:
10668 break;
10669 case ICmpInst::ICMP_EQ:
10670 case ICmpInst::ICMP_NE:
10671 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10672 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10673 Changed = true;
10674 break;
10675
10676 // The "Should have been caught earlier!" messages refer to the fact
10677 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10678 // should have fired on the corresponding cases, and canonicalized the
10679 // check to trivial case.
10680
10681 case ICmpInst::ICMP_UGE:
10682 assert(!RA.isMinValue() && "Should have been caught earlier!");
10683 Pred = ICmpInst::ICMP_UGT;
10684 RHS = getConstant(RA - 1);
10685 Changed = true;
10686 break;
10687 case ICmpInst::ICMP_ULE:
10688 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10689 Pred = ICmpInst::ICMP_ULT;
10690 RHS = getConstant(RA + 1);
10691 Changed = true;
10692 break;
10693 case ICmpInst::ICMP_SGE:
10694 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10695 Pred = ICmpInst::ICMP_SGT;
10696 RHS = getConstant(RA - 1);
10697 Changed = true;
10698 break;
10699 case ICmpInst::ICMP_SLE:
10700 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10701 Pred = ICmpInst::ICMP_SLT;
10702 RHS = getConstant(RA + 1);
10703 Changed = true;
10704 break;
10705 }
10706 }
10707 }
10708
10709 // Check for obvious equality.
10710 if (HasSameValue(LHS, RHS)) {
10711 if (ICmpInst::isTrueWhenEqual(Pred))
10712 return TrivialCase(true);
10713 if (ICmpInst::isFalseWhenEqual(Pred))
10714 return TrivialCase(false);
10715 }
10716
10717 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10718 // adding or subtracting 1 from one of the operands.
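// For example (editor's addition): "X <=s 7" becomes "X <s 8" when 8 is
// representable in the range of RHS, and otherwise "X - 1 <s 7" when X - 1
// cannot wrap in the signed sense.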
10719 switch (Pred) {
10720 case ICmpInst::ICMP_SLE:
10721 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10722 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10723 SCEV::FlagNSW);
10724 Pred = ICmpInst::ICMP_SLT;
10725 Changed = true;
10726 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10727 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10728 SCEV::FlagNSW);
10729 Pred = ICmpInst::ICMP_SLT;
10730 Changed = true;
10731 }
10732 break;
10733 case ICmpInst::ICMP_SGE:
10734 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10735 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10736 SCEV::FlagNSW);
10737 Pred = ICmpInst::ICMP_SGT;
10738 Changed = true;
10739 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10740 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10741 SCEV::FlagNSW);
10742 Pred = ICmpInst::ICMP_SGT;
10743 Changed = true;
10744 }
10745 break;
10746 case ICmpInst::ICMP_ULE:
10747 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10748 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10749 SCEV::FlagNUW);
10750 Pred = ICmpInst::ICMP_ULT;
10751 Changed = true;
10752 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10753 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10754 Pred = ICmpInst::ICMP_ULT;
10755 Changed = true;
10756 }
10757 break;
10758 case ICmpInst::ICMP_UGE:
10759 if (!getUnsignedRangeMin(RHS).isMinValue()) {
10760 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10761 Pred = ICmpInst::ICMP_UGT;
10762 Changed = true;
10763 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10764 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10765 SCEV::FlagNUW);
10766 Pred = ICmpInst::ICMP_UGT;
10767 Changed = true;
10768 }
10769 break;
10770 default:
10771 break;
10772 }
10773
10774 // TODO: More simplifications are possible here.
10775
10776 // Recursively simplify until we either hit a recursion limit or nothing
10777 // changes.
10778 if (Changed)
10779 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10780
10781 return Changed;
10782}
10783
10784 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10785 return getSignedRangeMax(S).isNegative();
10786 }
10787
10788 bool ScalarEvolution::isKnownPositive(const SCEV *S) {
10789 return getSignedRangeMin(S).isStrictlyPositive();
10790 }
10791
10792 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10793 return !getSignedRangeMin(S).isNegative();
10794 }
10795
10796 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
10797 return !getSignedRangeMax(S).isStrictlyPositive();
10798 }
10799
10800 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10801 // Query push down for cases where the unsigned range is
10802 // less than sufficient.
10803 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10804 return isKnownNonZero(SExt->getOperand(0));
10805 return getUnsignedRangeMin(S) != 0;
10806}
10807
10808std::pair<const SCEV *, const SCEV *>
10809 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
10810 // Compute SCEV on entry of loop L.
10811 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10812 if (Start == getCouldNotCompute())
10813 return { Start, Start };
10814 // Compute post increment SCEV for loop L.
10815 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10816 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10817 return { Start, PostInc };
10818}
10819
10820 bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
10821 const SCEV *LHS, const SCEV *RHS) {
10822 // First collect all loops.
10823 SmallPtrSet<const Loop *, 2> LoopsUsed;
10824 getUsedLoops(LHS, LoopsUsed);
10825 getUsedLoops(RHS, LoopsUsed);
10826
10827 if (LoopsUsed.empty())
10828 return false;
10829
10830 // Domination relationship must be a linear order on collected loops.
10831#ifndef NDEBUG
10832 for (const auto *L1 : LoopsUsed)
10833 for (const auto *L2 : LoopsUsed)
10834 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
10835 DT.dominates(L2->getHeader(), L1->getHeader())) &&
10836 "Domination relationship is not a linear order");
10837#endif
10838
10839 const Loop *MDL =
10840 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
10841 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
10842 });
10843
10844 // Get init and post increment value for LHS.
10845 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
10846 // If LHS contains an unknown non-invariant SCEV, bail out.
10847 if (SplitLHS.first == getCouldNotCompute())
10848 return false;
10849 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
10850 // Get init and post increment value for RHS.
10851 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
10852 // If RHS contains an unknown non-invariant SCEV, bail out.
10853 if (SplitRHS.first == getCouldNotCompute())
10854 return false;
10855 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
10856 // It is possible that init SCEV contains an invariant load but it does
10857 // not dominate MDL and is not available at MDL loop entry, so we should
10858 // check it here.
10859 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
10860 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
10861 return false;
10862
10863 // It seems the backedge guard check is faster than the entry one, so in some
10864 // cases it can speed up the whole estimation by short-circuiting.
10865 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
10866 SplitRHS.second) &&
10867 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
10868}
10869
10870 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
10871 const SCEV *LHS, const SCEV *RHS) {
10872 // Canonicalize the inputs first.
10873 (void)SimplifyICmpOperands(Pred, LHS, RHS);
10874
10875 if (isKnownViaInduction(Pred, LHS, RHS))
10876 return true;
10877
10878 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
10879 return true;
10880
10881 // Otherwise see what can be done with some simple reasoning.
10882 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
10883}
10884
10885 std::optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
10886 const SCEV *LHS,
10887 const SCEV *RHS) {
10888 if (isKnownPredicate(Pred, LHS, RHS))
10889 return true;
10890 if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
10891 return false;
10892 return std::nullopt;
10893}
10894
10895bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
10896 const SCEV *LHS, const SCEV *RHS,
10897 const Instruction *CtxI) {
10898 // TODO: Analyze guards and assumes from Context's block.
10899 return isKnownPredicate(Pred, LHS, RHS) ||
10900 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
10901}
10902
10903std::optional<bool>
10904ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
10905 const SCEV *RHS, const Instruction *CtxI) {
10906 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
10907 if (KnownWithoutContext)
10908 return KnownWithoutContext;
10909
10910 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
10911 return true;
10912 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
10913 ICmpInst::getInversePredicate(Pred),
10914 LHS, RHS))
10915 return false;
10916 return std::nullopt;
10917}
10918
10919bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
10920 const SCEVAddRecExpr *LHS,
10921 const SCEV *RHS) {
10922 const Loop *L = LHS->getLoop();
10923 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
10924 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
10925}
10926
10927std::optional<ScalarEvolution::MonotonicPredicateType>
10928ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
10929 ICmpInst::Predicate Pred) {
10930 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
10931
10932#ifndef NDEBUG
10933 // Verify an invariant: inverting the predicate should turn a monotonically
10934 // increasing change to a monotonically decreasing one, and vice versa.
10935 if (Result) {
10936 auto ResultSwapped =
10937 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
10938
10939 assert(*ResultSwapped != *Result &&
10940 "monotonicity should flip as we flip the predicate");
10941 }
10942#endif
10943
10944 return Result;
10945}
10946
10947std::optional<ScalarEvolution::MonotonicPredicateType>
10948ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
10949 ICmpInst::Predicate Pred) {
10950 // A zero step value for LHS means the induction variable is essentially a
10951 // loop invariant value. We don't really depend on the predicate actually
10952 // flipping from false to true (for increasing predicates, and the other way
10953 // around for decreasing predicates); all we care about is that *if* the
10954 // predicate changes, then it only changes from false to true.
10955 //
10956 // A zero step value in itself is not very useful, but there may be places
10957 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
10958 // as general as possible.
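// For instance, with LHS = {0,+,1}<nsw> and a loop-invariant right-hand
// side, "LHS s> RHS" can only change from false to true as the IV grows
// (MonotonicallyIncreasing), while "LHS s< RHS" can only change from true
// to false (MonotonicallyDecreasing).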
10959
10960 // Only handle LE/LT/GE/GT predicates.
10961 if (!ICmpInst::isRelational(Pred))
10962 return std::nullopt;
10963
10964 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
10965 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
10966 "Should be greater or less!");
10967
10968 // Check that AR does not wrap.
10969 if (ICmpInst::isUnsigned(Pred)) {
10970 if (!LHS->hasNoUnsignedWrap())
10971 return std::nullopt;
10972 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10973 }
10974 assert(ICmpInst::isSigned(Pred) &&
10975 "Relational predicate is either signed or unsigned!");
10976 if (!LHS->hasNoSignedWrap())
10977 return std::nullopt;
10978
10979 const SCEV *Step = LHS->getStepRecurrence(*this);
10980
10981 if (isKnownNonNegative(Step))
10982 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10983
10984 if (isKnownNonPositive(Step))
10985 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10986
10987 return std::nullopt;
10988}
10989
10990std::optional<ScalarEvolution::LoopInvariantPredicate>
10991ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
10992 const SCEV *LHS, const SCEV *RHS,
10993 const Loop *L,
10994 const Instruction *CtxI) {
10995 // If one operand is loop-invariant, force it into the RHS; otherwise bail out.
10996 if (!isLoopInvariant(RHS, L)) {
10997 if (!isLoopInvariant(LHS, L))
10998 return std::nullopt;
10999
11000 std::swap(LHS, RHS);
11001 Pred = ICmpInst::getSwappedPredicate(Pred);
11002 }
11003
11004 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11005 if (!ArLHS || ArLHS->getLoop() != L)
11006 return std::nullopt;
11007
11008 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11009 if (!MonotonicType)
11010 return std::nullopt;
11011 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11012 // true as the loop iterates, and the backedge is control dependent on
11013 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11014 //
11015 // * if the predicate was false in the first iteration then the predicate
11016 // is never evaluated again, since the loop exits without taking the
11017 // backedge.
11018 // * if the predicate was true in the first iteration then it will
11019 // continue to be true for all future iterations since it is
11020 // monotonically increasing.
11021 //
11022 // For both the above possibilities, we can replace the loop varying
11023 // predicate with its value on the first iteration of the loop (which is
11024 // loop invariant).
11025 //
11026 // A similar reasoning applies for a monotonically decreasing predicate, by
11027 // replacing true with false and false with true in the above two bullets.
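// For example, with ArLHS = {X,+,1}<nsw>, Pred = s> and an invariant RHS:
// once "ArLHS s> RHS" becomes true it stays true, so if the backedge is
// guarded by this condition the whole loop-varying test can be replaced by
// its value on the first iteration, "X s> RHS".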
11028 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11029 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
11030
11031 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11032 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11033 RHS);
11034
11035 if (!CtxI)
11036 return std::nullopt;
11037 // Try to prove via context.
11038 // TODO: Support other cases.
11039 switch (Pred) {
11040 default:
11041 break;
11042 case ICmpInst::ICMP_ULE:
11043 case ICmpInst::ICMP_ULT: {
11044 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11045 // Given preconditions
11046 // (1) ArLHS does not cross the border of positive and negative parts of
11047 // range because of:
11048 // - Positive step; (TODO: lift this limitation)
11049 // - nuw - does not cross zero boundary;
11050 // - nsw - does not cross SINT_MAX boundary;
11051 // (2) ArLHS <s RHS
11052 // (3) RHS >=s 0
11053 // we can replace the loop variant ArLHS <u RHS condition with loop
11054 // invariant Start(ArLHS) <u RHS.
11055 //
11056 // Because of (1) there are two options:
11057 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11058 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11059 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11060 // Because of (2) ArLHS <u RHS is trivially true.
11061 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11062 // We can strengthen this to Start(ArLHS) <u RHS.
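// Concrete illustration: ArLHS = {0,+,1}<nuw><nsw> with a positive step,
// RHS known non-negative, and the context proving "ArLHS s< RHS". Then
// ArLHS stays non-negative, s< and u< agree on it, and the loop-varying
// "ArLHS u< RHS" can be replaced by the invariant "0 u< RHS".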
11063 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11064 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11065 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11066 isKnownNonNegative(RHS) &&
11067 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11068 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11069 RHS);
11070 }
11071 }
11072
11073 return std::nullopt;
11074}
11075
11076std::optional<ScalarEvolution::LoopInvariantPredicate>
11077ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11078 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11079 const Instruction *CtxI, const SCEV *MaxIter) {
11080 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11081 Pred, LHS, RHS, L, CtxI, MaxIter))
11082 return LIP;
11083 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11084 // Number of iterations expressed as UMIN isn't always great for expressing
11085 // the value on the last iteration. If the straightforward approach didn't
11086 // work, try the following trick: if a predicate is invariant for X, it
11087 // is also invariant for umin(X, ...). So try to find something that works
11088 // among subexpressions of MaxIter expressed as umin.
11089 for (auto *Op : UMin->operands())
11090 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11091 Pred, LHS, RHS, L, CtxI, Op))
11092 return LIP;
11093 return std::nullopt;
11094}
11095
11096std::optional<ScalarEvolution::LoopInvariantPredicate>
11097ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11098 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11099 const Instruction *CtxI, const SCEV *MaxIter) {
11100 // Try to prove the following set of facts:
11101 // - The predicate is monotonic in the iteration space.
11102 // - If the check does not fail on the 1st iteration:
11103 // - No overflow will happen during first MaxIter iterations;
11104 // - It will not fail on the MaxIter'th iteration.
11105 // If the check does fail on the 1st iteration, we leave the loop and no
11106 // other checks matter.
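// Rough illustration: for the check "{0,+,1} s< RHS" with MaxIter = 10, the
// IV value on the last permitted iteration is 10; if "10 s< RHS" is
// guaranteed whenever the backedge is taken and the IV cannot wrap on the
// way there, the check cannot pass on the 1st iteration and then fail
// within the first 10, so it behaves like the invariant "0 s< RHS".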
11107
11108 // If one operand is loop-invariant, force it into the RHS; otherwise bail out.
11109 if (!isLoopInvariant(RHS, L)) {
11110 if (!isLoopInvariant(LHS, L))
11111 return std::nullopt;
11112
11113 std::swap(LHS, RHS);
11114 Pred = ICmpInst::getSwappedPredicate(Pred);
11115 }
11116
11117 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11118 if (!AR || AR->getLoop() != L)
11119 return std::nullopt;
11120
11121 // The predicate must be relational (i.e. <, <=, >=, >).
11122 if (!ICmpInst::isRelational(Pred))
11123 return std::nullopt;
11124
11125 // TODO: Support steps other than +/- 1.
11126 const SCEV *Step = AR->getStepRecurrence(*this);
11127 auto *One = getOne(Step->getType());
11128 auto *MinusOne = getNegativeSCEV(One);
11129 if (Step != One && Step != MinusOne)
11130 return std::nullopt;
11131
11132 // Type mismatch here means that MaxIter is potentially larger than max
11133 // unsigned value in the start type, which means we cannot prove no-wrap for the
11134 // indvar.
11135 if (AR->getType() != MaxIter->getType())
11136 return std::nullopt;
11137
11138 // Value of IV on suggested last iteration.
11139 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11140 // Does it still meet the requirement?
11141 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11142 return std::nullopt;
11143 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11144 // not exceed max unsigned value of this type), this effectively proves
11145 // that there is no wrap during the iteration. To prove that there is no
11146 // signed/unsigned wrap, we need to check that
11147 // Start <= Last for step = 1 or Start >= Last for step = -1.
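// E.g. with an i8 IV, Start = 250, step = 1 and MaxIter = 10: Last evaluates
// to 4 only because the value wrapped past 255, and the unsigned check
// "250 u<= 4" fails, correctly rejecting the case.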
11148 ICmpInst::Predicate NoOverflowPred =
11149 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11150 if (Step == MinusOne)
11151 NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
11152 const SCEV *Start = AR->getStart();
11153 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11154 return std::nullopt;
11155
11156 // Everything is fine.
11157 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11158}
11159
11160bool ScalarEvolution::isKnownPredicateViaConstantRanges(
11161 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
11162 if (HasSameValue(LHS, RHS))
11163 return ICmpInst::isTrueWhenEqual(Pred);
11164
11165 // This code is split out from isKnownPredicate because it is called from
11166 // within isLoopEntryGuardedByCond.
11167
11168 auto CheckRanges = [&](const ConstantRange &RangeLHS,
11169 const ConstantRange &RangeRHS) {
11170 return RangeLHS.icmp(Pred, RangeRHS);
11171 };
11172
11173 // The check at the top of the function catches the case where the values are
11174 // known to be equal.
11175 if (Pred == CmpInst::ICMP_EQ)
11176 return false;
11177
11178 if (Pred == CmpInst::ICMP_NE) {
11179 auto SL = getSignedRange(LHS);
11180 auto SR = getSignedRange(RHS);
11181 if (CheckRanges(SL, SR))
11182 return true;
11183 auto UL = getUnsignedRange(LHS);
11184 auto UR = getUnsignedRange(RHS);
11185 if (CheckRanges(UL, UR))
11186 return true;
11187 auto *Diff = getMinusSCEV(LHS, RHS);
11188 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11189 }
11190
11191 if (CmpInst::isSigned(Pred)) {
11192 auto SL = getSignedRange(LHS);
11193 auto SR = getSignedRange(RHS);
11194 return CheckRanges(SL, SR);
11195 }
11196
11197 auto UL = getUnsignedRange(LHS);
11198 auto UR = getUnsignedRange(RHS);
11199 return CheckRanges(UL, UR);
11200}
11201
11202bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
11203 const SCEV *LHS,
11204 const SCEV *RHS) {
11205 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11206 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11207 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11208 // OutC1 and OutC2.
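// E.g. LHS = (%a + 3)<nsw> and RHS = (%a + 7)<nsw> match with C1 = 3 and
// C2 = 7, so the SLE case below concludes LHS s<= RHS because 3 s<= 7.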
11209 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11210 APInt &OutC1, APInt &OutC2,
11211 SCEV::NoWrapFlags ExpectedFlags) {
11212 const SCEV *XNonConstOp, *XConstOp;
11213 const SCEV *YNonConstOp, *YConstOp;
11214 SCEV::NoWrapFlags XFlagsPresent;
11215 SCEV::NoWrapFlags YFlagsPresent;
11216
11217 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11218 XConstOp = getZero(X->getType());
11219 XNonConstOp = X;
11220 XFlagsPresent = ExpectedFlags;
11221 }
11222 if (!isa<SCEVConstant>(XConstOp) ||
11223 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11224 return false;
11225
11226 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11227 YConstOp = getZero(Y->getType());
11228 YNonConstOp = Y;
11229 YFlagsPresent = ExpectedFlags;
11230 }
11231
11232 if (!isa<SCEVConstant>(YConstOp) ||
11233 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11234 return false;
11235
11236 if (YNonConstOp != XNonConstOp)
11237 return false;
11238
11239 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11240 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11241
11242 return true;
11243 };
11244
11245 APInt C1;
11246 APInt C2;
11247
11248 switch (Pred) {
11249 default:
11250 break;
11251
11252 case ICmpInst::ICMP_SGE:
11253 std::swap(LHS, RHS);
11254 [[fallthrough]];
11255 case ICmpInst::ICMP_SLE:
11256 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11257 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11258 return true;
11259
11260 break;
11261
11262 case ICmpInst::ICMP_SGT:
11263 std::swap(LHS, RHS);
11264 [[fallthrough]];
11265 case ICmpInst::ICMP_SLT:
11266 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11267 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11268 return true;
11269
11270 break;
11271
11272 case ICmpInst::ICMP_UGE:
11273 std::swap(LHS, RHS);
11274 [[fallthrough]];
11275 case ICmpInst::ICMP_ULE:
11276 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11277 if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
11278 return true;
11279
11280 break;
11281
11282 case ICmpInst::ICMP_UGT:
11283 std::swap(LHS, RHS);
11284 [[fallthrough]];
11285 case ICmpInst::ICMP_ULT:
11286 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11287 if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
11288 return true;
11289 break;
11290 }
11291
11292 return false;
11293}
11294
11295bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
11296 const SCEV *LHS,
11297 const SCEV *RHS) {
11298 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11299 return false;
11300
11301 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting on
11302 // the stack can result in exponential time complexity.
11303 SaveAndRestore Restore(ProvingSplitPredicate, true);
11304
11305 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11306 //
11307 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11308 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11309 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11310 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11311 // use isKnownPredicate later if needed.
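// E.g. to prove "%i u< %n" with %n known non-negative, it is enough to
// prove "%i s>= 0" and "%i s< %n": on the non-negative half of the range
// the signed and unsigned orderings coincide.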
11312 return isKnownNonNegative(RHS) &&
11313 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11314 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11315}
11316
11317bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
11318 ICmpInst::Predicate Pred,
11319 const SCEV *LHS, const SCEV *RHS) {
11320 // No need to even try if we know the module has no guards.
11321 if (!HasGuards)
11322 return false;
11323
11324 return any_of(*BB, [&](const Instruction &I) {
11325 using namespace llvm::PatternMatch;
11326
11327 Value *Condition;
11328 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11329 m_Value(Condition))) &&
11330 isImpliedCond(Pred, LHS, RHS, Condition, false);
11331 });
11332}
11333
11334/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11335/// protected by a conditional between LHS and RHS. This is used to
11336/// eliminate casts.
11337bool
11338ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11339 ICmpInst::Predicate Pred,
11340 const SCEV *LHS, const SCEV *RHS) {
11341 // Interpret a null as meaning no loop, where there is obviously no guard
11342 // (interprocedural conditions notwithstanding). Do not bother about
11343 // unreachable loops.
11344 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11345 return true;
11346
11347 if (VerifyIR)
11348 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11349 "This cannot be done on broken IR!");
11350
11351
11352 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11353 return true;
11354
11355 BasicBlock *Latch = L->getLoopLatch();
11356 if (!Latch)
11357 return false;
11358
11359 BranchInst *LoopContinuePredicate =
11360 dyn_cast<BranchInst>(Latch->getTerminator());
11361 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11362 isImpliedCond(Pred, LHS, RHS,
11363 LoopContinuePredicate->getCondition(),
11364 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11365 return true;
11366
11367 // We don't want more than one activation of the following loops on the stack
11368 // -- that can lead to O(n!) time complexity.
11369 if (WalkingBEDominatingConds)
11370 return false;
11371
11372 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11373
11374 // See if we can exploit a trip count to prove the predicate.
11375 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11376 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11377 if (LatchBECount != getCouldNotCompute()) {
11378 // We know that Latch branches back to the loop header exactly
11379 // LatchBECount times. This means the backedge condition at Latch is
11380 // equivalent to "{0,+,1} u< LatchBECount".
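// E.g. if LatchBECount is 100, the latch branches back only on iterations
// numbered 0 through 99, so any fact implied by "{0,+,1} u< 100" may be
// assumed whenever the backedge is taken.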
11381 Type *Ty = LatchBECount->getType();
11382 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11383 const SCEV *LoopCounter =
11384 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11385 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11386 LatchBECount))
11387 return true;
11388 }
11389
11390 // Check conditions due to any @llvm.assume intrinsics.
11391 for (auto &AssumeVH : AC.assumptions()) {
11392 if (!AssumeVH)
11393 continue;
11394 auto *CI = cast<CallInst>(AssumeVH);
11395 if (!DT.dominates(CI, Latch->getTerminator()))
11396 continue;
11397
11398 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11399 return true;
11400 }
11401
11402 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11403 return true;
11404
11405 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11406 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11407 assert(DTN && "should reach the loop header before reaching the root!");
11408
11409 BasicBlock *BB = DTN->getBlock();
11410 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11411 return true;
11412
11413 BasicBlock *PBB = BB->getSinglePredecessor();
11414 if (!PBB)
11415 continue;
11416
11417 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11418 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11419 continue;
11420
11421 Value *Condition = ContinuePredicate->getCondition();
11422
11423 // If we have an edge `E` within the loop body that dominates the only
11424 // latch, the condition guarding `E` also guards the backedge. This
11425 // reasoning works only for loops with a single latch.
11426
11427 BasicBlockEdge DominatingEdge(PBB, BB);
11428 if (DominatingEdge.isSingleEdge()) {
11429 // We're constructively (and conservatively) enumerating edges within the
11430 // loop body that dominate the latch. The dominator tree better agree
11431 // with us on this:
11432 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11433
11434 if (isImpliedCond(Pred, LHS, RHS, Condition,
11435 BB != ContinuePredicate->getSuccessor(0)))
11436 return true;
11437 }
11438 }
11439
11440 return false;
11441}
11442
11443bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11444 ICmpInst::Predicate Pred,
11445 const SCEV *LHS,
11446 const SCEV *RHS) {
11447 // Do not bother proving facts for unreachable code.
11448 if (!DT.isReachableFromEntry(BB))
11449 return true;
11450 if (VerifyIR)
11451 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11452 "This cannot be done on broken IR!");
11453
11454 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11455 // the facts (a >= b && a != b) separately. A typical situation is when the
11456 // non-strict comparison is known from ranges and non-equality is known from
11457 // dominating predicates. If we are proving strict comparison, we always try
11458 // to prove non-equality and non-strict comparison separately.
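// E.g. to prove "a u> b" we may instead prove "a u>= b" (say, from unsigned
// ranges) together with "a != b" (say, from a dominating branch); the two
// facts combined give the strict inequality.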
11459 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
11460 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
11461 bool ProvedNonStrictComparison = false;
11462 bool ProvedNonEquality = false;
11463
11464 auto SplitAndProve =
11465 [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
11466 if (!ProvedNonStrictComparison)
11467 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11468 if (!ProvedNonEquality)
11469 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11470 if (ProvedNonStrictComparison && ProvedNonEquality)
11471 return true;
11472 return false;
11473 };
11474
11475 if (ProvingStrictComparison) {
11476 auto ProofFn = [&](ICmpInst::Predicate P) {
11477 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11478 };
11479 if (SplitAndProve(ProofFn))
11480 return true;
11481 }
11482
11483 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11484 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11485 const Instruction *CtxI = &BB->front();
11486 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11487 return true;
11488 if (ProvingStrictComparison) {
11489 auto ProofFn = [&](ICmpInst::Predicate P) {
11490 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11491 };
11492 if (SplitAndProve(ProofFn))
11493 return true;
11494 }
11495 return false;
11496 };
11497
11498 // Starting at the block's predecessor, climb up the predecessor chain, as long
11499 // as there are predecessors that can be found that have unique successors
11500 // leading to the original block.
11501 const Loop *ContainingLoop = LI.getLoopFor(BB);
11502 const BasicBlock *PredBB;
11503 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11504 PredBB = ContainingLoop->getLoopPredecessor();
11505 else
11506 PredBB = BB->getSinglePredecessor();
11507 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11508 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11509 const BranchInst *BlockEntryPredicate =
11510 dyn_cast<BranchInst>(Pair.first->getTerminator());
11511 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11512 continue;
11513
11514 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11515 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11516 return true;
11517 }
11518
11519 // Check conditions due to any @llvm.assume intrinsics.
11520 for (auto &AssumeVH : AC.assumptions()) {
11521 if (!AssumeVH)
11522 continue;
11523 auto *CI = cast<CallInst>(AssumeVH);
11524 if (!DT.dominates(CI, BB))
11525 continue;
11526
11527 if (ProveViaCond(CI->getArgOperand(0), false))
11528 return true;
11529 }
11530
11531 // Check conditions due to any @llvm.experimental.guard intrinsics.
11532 auto *GuardDecl = F.getParent()->getFunction(
11533 Intrinsic::getName(Intrinsic::experimental_guard));
11534 if (GuardDecl)
11535 for (const auto *GU : GuardDecl->users())
11536 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11537 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11538 if (ProveViaCond(Guard->getArgOperand(0), false))
11539 return true;
11540 return false;
11541}
11542
11543bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
11544 ICmpInst::Predicate Pred,
11545 const SCEV *LHS,
11546 const SCEV *RHS) {
11547 // Interpret a null as meaning no loop, where there is obviously no guard
11548 // (interprocedural conditions notwithstanding).
11549 if (!L)
11550 return false;
11551
11552 // Both LHS and RHS must be available at loop entry.
11553 assert(isAvailableAtLoopEntry(LHS, L) &&
11554 "LHS is not available at Loop Entry");
11555 assert(isAvailableAtLoopEntry(RHS, L) &&
11556 "RHS is not available at Loop Entry");
11557
11558 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11559 return true;
11560
11561 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11562}
11563
11564bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11565 const SCEV *RHS,
11566 const Value *FoundCondValue, bool Inverse,
11567 const Instruction *CtxI) {
11568 // A false condition implies anything. Do not bother analyzing it further.
11569 if (FoundCondValue ==
11570 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11571 return true;
11572
11573 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11574 return false;
11575
11576 auto ClearOnExit =
11577 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11578
11579 // Recursively handle And and Or conditions.
11580 const Value *Op0, *Op1;
11581 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11582 if (!Inverse)
11583 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11584 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11585 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11586 if (Inverse)
11587 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11588 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11589 }
11590
11591 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11592 if (!ICI) return false;
11593
11594 // We have now found a conditional branch that dominates the loop or controls
11595 // the loop latch. Check to see if it is the comparison we are looking for.
11596 ICmpInst::Predicate FoundPred;
11597 if (Inverse)
11598 FoundPred = ICI->getInversePredicate();
11599 else
11600 FoundPred = ICI->getPredicate();
11601
11602 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11603 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11604
11605 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11606}
11607
11608bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11609 const SCEV *RHS,
11610 ICmpInst::Predicate FoundPred,
11611 const SCEV *FoundLHS, const SCEV *FoundRHS,
11612 const Instruction *CtxI) {
11613 // Balance the types.
11614 if (getTypeSizeInBits(LHS->getType()) <
11615 getTypeSizeInBits(FoundLHS->getType())) {
11616 // For unsigned and equality predicates, try to prove that both found
11617 // operands fit into a narrow unsigned range. If so, try to prove facts in
11618 // narrow types.
11619 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11620 !FoundRHS->getType()->isPointerTy()) {
11621 auto *NarrowType = LHS->getType();
11622 auto *WideType = FoundLHS->getType();
11623 auto BitWidth = getTypeSizeInBits(NarrowType);
11624 const SCEV *MaxValue = getZeroExtendExpr(
11625 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11626 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11627 MaxValue) &&
11628 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11629 MaxValue)) {
11630 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11631 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11632 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
11633 TruncFoundRHS, CtxI))
11634 return true;
11635 }
11636 }
11637
11638 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11639 return false;
11640 if (CmpInst::isSigned(Pred)) {
11641 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11642 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11643 } else {
11644 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11645 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11646 }
11647 } else if (getTypeSizeInBits(LHS->getType()) >
11648 getTypeSizeInBits(FoundLHS->getType())) {
11649 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11650 return false;
11651 if (CmpInst::isSigned(FoundPred)) {
11652 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11653 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11654 } else {
11655 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11656 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11657 }
11658 }
11659 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11660 FoundRHS, CtxI);
11661}
11662
11663bool ScalarEvolution::isImpliedCondBalancedTypes(
11664 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11665 ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
11666 const Instruction *CtxI) {
11667 assert(getTypeSizeInBits(LHS->getType()) ==
11668 getTypeSizeInBits(FoundLHS->getType()) &&
11669 "Types should be balanced!");
11670 // Canonicalize the query to match the way instcombine will have
11671 // canonicalized the comparison.
11672 if (SimplifyICmpOperands(Pred, LHS, RHS))
11673 if (LHS == RHS)
11674 return CmpInst::isTrueWhenEqual(Pred);
11675 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11676 if (FoundLHS == FoundRHS)
11677 return CmpInst::isFalseWhenEqual(FoundPred);
11678
11679 // Check to see if we can make the LHS or RHS match.
11680 if (LHS == FoundRHS || RHS == FoundLHS) {
11681 if (isa<SCEVConstant>(RHS)) {
11682 std::swap(FoundLHS, FoundRHS);
11683 FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
11684 } else {
11685 std::swap(LHS, RHS);
11686 Pred = ICmpInst::getSwappedPredicate(Pred);
11687 }
11688 }
11689
11690 // Check whether the found predicate is the same as the desired predicate.
11691 if (FoundPred == Pred)
11692 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11693
11694 // Check whether swapping the found predicate makes it the same as the
11695 // desired predicate.
11696 if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
11697 // We can write the implication
11698 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11699 // using one of the following ways:
11700 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11701 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11702 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11703 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11704 // Forms 1. and 2. require swapping the operands of one condition. Don't
11705 // do this if it would break canonical constant/addrec ordering.
11706 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11707 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11708 CtxI);
11709 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11710 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11711
11712 // There's no clear preference between forms 3. and 4., try both. Avoid
11713 // forming getNotSCEV of pointer values as the resulting subtract is
11714 // not legal.
11715 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11716 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11717 FoundLHS, FoundRHS, CtxI))
11718 return true;
11719
11720 if (!FoundLHS->getType()->isPointerTy() &&
11721 !FoundRHS->getType()->isPointerTy() &&
11722 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11723 getNotSCEV(FoundRHS), CtxI))
11724 return true;
11725
11726 return false;
11727 }
11728
11729 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11730 CmpInst::Predicate P2) {
11731 assert(P1 != P2 && "Handled earlier!");
11732 return CmpInst::isRelational(P2) &&
11733 P1 == CmpInst::getFlippedSignednessPredicate(P2);
11734 };
11735 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11736 // Unsigned comparison is the same as signed comparison when both the
11737 // operands are non-negative or negative.
11738 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11739 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11740 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11741 // Create local copies that we can freely swap and canonicalize our
11742 // conditions to "le/lt".
11743 ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11744 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11745 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11746 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11747 CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
11748 CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
11749 std::swap(CanonicalLHS, CanonicalRHS);
11750 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11751 }
11752 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11753 "Must be!");
11754 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11755 ICmpInst::isLE(CanonicalFoundPred)) &&
11756 "Must be!");
11757 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11758 // Use implication:
11759 // x <u y && y >=s 0 --> x <s y.
11760 // If we can prove the left part, the right part is also proven.
11761 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11762 CanonicalRHS, CanonicalFoundLHS,
11763 CanonicalFoundRHS);
11764 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11765 // Use implication:
11766 // x <s y && y <s 0 --> x <u y.
11767 // If we can prove the left part, the right part is also proven.
11768 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11769 CanonicalRHS, CanonicalFoundLHS,
11770 CanonicalFoundRHS);
11771 }
11772
11773 // Check if we can make progress by sharpening ranges.
11774 if (FoundPred == ICmpInst::ICMP_NE &&
11775 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11776
11777 const SCEVConstant *C = nullptr;
11778 const SCEV *V = nullptr;
11779
11780 if (isa<SCEVConstant>(FoundLHS)) {
11781 C = cast<SCEVConstant>(FoundLHS);
11782 V = FoundRHS;
11783 } else {
11784 C = cast<SCEVConstant>(FoundRHS);
11785 V = FoundLHS;
11786 }
11787
11788 // The guarding predicate tells us that C != V. If the known range
11789 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11790 // range we consider has to correspond to same signedness as the
11791 // predicate we're interested in folding.
11792
11793 APInt Min = ICmpInst::isSigned(Pred) ?
11794 getSignedRangeMin(V) : getUnsignedRangeMin(V);
11795
11796 if (Min == C->getAPInt()) {
11797 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11798 // This is true even if (Min + 1) wraps around -- in case of
11799 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
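// E.g. for an unsigned predicate with Min = 0: knowing V != 0 on this path
// sharpens the trivial "V u>= 0" into the useful fact "V u>= 1".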
11800
11801 APInt SharperMin = Min + 1;
11802
11803 switch (Pred) {
11804 case ICmpInst::ICMP_SGE:
11805 case ICmpInst::ICMP_UGE:
11806 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11807 // RHS, we're done.
11808 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11809 CtxI))
11810 return true;
11811 [[fallthrough]];
11812
11813 case ICmpInst::ICMP_SGT:
11814 case ICmpInst::ICMP_UGT:
11815 // We know from the range information that (V `Pred` Min ||
11816 // V == Min). We know from the guarding condition that !(V
11817 // == Min). This gives us
11818 //
11819 // V `Pred` Min || V == Min && !(V == Min)
11820 // => V `Pred` Min
11821 //
11822 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
11823
11824 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
11825 return true;
11826 break;
11827
11828 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
11829 case ICmpInst::ICMP_SLE:
11830 case ICmpInst::ICMP_ULE:
11831 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11832 LHS, V, getConstant(SharperMin), CtxI))
11833 return true;
11834 [[fallthrough]];
11835
11836 case ICmpInst::ICMP_SLT:
11837 case ICmpInst::ICMP_ULT:
11838 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11839 LHS, V, getConstant(Min), CtxI))
11840 return true;
11841 break;
11842
11843 default:
11844 // No change
11845 break;
11846 }
11847 }
11848 }
11849
11850 // Check whether the actual condition is beyond sufficient.
11851 if (FoundPred == ICmpInst::ICMP_EQ)
11852 if (ICmpInst::isTrueWhenEqual(Pred))
11853 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11854 return true;
11855 if (Pred == ICmpInst::ICMP_NE)
11856 if (!ICmpInst::isTrueWhenEqual(FoundPred))
11857 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11858 return true;
11859
11860 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
11861 return true;
11862
11863 // Otherwise assume the worst.
11864 return false;
11865}
11866
11867bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
11868 const SCEV *&L, const SCEV *&R,
11869 SCEV::NoWrapFlags &Flags) {
11870 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
11871 if (!AE || AE->getNumOperands() != 2)
11872 return false;
11873
11874 L = AE->getOperand(0);
11875 R = AE->getOperand(1);
11876 Flags = AE->getNoWrapFlags();
11877 return true;
11878}
11879
11880std::optional<APInt>
11881ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
11882 // We avoid subtracting expressions here because this function is usually
11883 // fairly deep in the call stack (i.e. is called many times).
11884
11885 // X - X = 0.
11886 if (More == Less)
11887 return APInt(getTypeSizeInBits(More->getType()), 0);
11888
11889 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
11890 const auto *LAR = cast<SCEVAddRecExpr>(Less);
11891 const auto *MAR = cast<SCEVAddRecExpr>(More);
11892
11893 if (LAR->getLoop() != MAR->getLoop())
11894 return std::nullopt;
11895
11896 // We look at affine expressions only; not for correctness but to keep
11897 // getStepRecurrence cheap.
11898 if (!LAR->isAffine() || !MAR->isAffine())
11899 return std::nullopt;
11900
11901 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
11902 return std::nullopt;
11903
11904 Less = LAR->getStart();
11905 More = MAR->getStart();
11906
11907 // fall through
11908 }
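// E.g. More = {42,+,3}<L> and Less = {40,+,3}<L> have the same loop and the
// same step, so the difference reduces to that of the starts, giving 2.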
11909
11910 if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
11911 const auto &M = cast<SCEVConstant>(More)->getAPInt();
11912 const auto &L = cast<SCEVConstant>(Less)->getAPInt();
11913 return M - L;
11914 }
11915
11916 SCEV::NoWrapFlags Flags;
11917 const SCEV *LLess = nullptr, *RLess = nullptr;
11918 const SCEV *LMore = nullptr, *RMore = nullptr;
11919 const SCEVConstant *C1 = nullptr, *C2 = nullptr;
11920 // Compare (X + C1) vs X.
11921 if (splitBinaryAdd(Less, LLess, RLess, Flags))
11922 if ((C1 = dyn_cast<SCEVConstant>(LLess)))
11923 if (RLess == More)
11924 return -(C1->getAPInt());
11925
11926 // Compare X vs (X + C2).
11927 if (splitBinaryAdd(More, LMore, RMore, Flags))
11928 if ((C2 = dyn_cast<SCEVConstant>(LMore)))
11929 if (RMore == Less)
11930 return C2->getAPInt();
11931
11932 // Compare (X + C1) vs (X + C2).
11933 if (C1 && C2 && RLess == RMore)
11934 return C2->getAPInt() - C1->getAPInt();
11935
11936 return std::nullopt;
11937}
11938
11939bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
11940 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11941 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11942 // Try to recognize the following pattern:
11943 //
11944 // FoundRHS = ...
11945 // ...
11946 // loop:
11947 // FoundLHS = {Start,+,W}
11948 // context_bb: // Basic block from the same loop
11949 // known(Pred, FoundLHS, FoundRHS)
11950 //
11951 // If some predicate is known in the context of a loop, it is also known on
11952 // each iteration of this loop, including the first iteration. Therefore, in
11953 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
11954 // prove the original pred using this fact.
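// E.g. if a block inside the loop is known to satisfy "{X,+,1} u< N" and
// that block is reached on the first iteration whenever the loop runs, then
// "X u< N" holds as well and may suffice to prove the queried predicate.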
11955 if (!CtxI)
11956 return false;
11957 const BasicBlock *ContextBB = CtxI->getParent();
11958 // Make sure AR varies in the context block.
11959 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
11960 const Loop *L = AR->getLoop();
11961 // Make sure that context belongs to the loop and executes on 1st iteration
11962 // (if it ever executes at all).
11963 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
11964 return false;
11965 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
11966 return false;
11967 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
11968 }
11969
11970 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
11971 const Loop *L = AR->getLoop();
11972 // Make sure that context belongs to the loop and executes on 1st iteration
11973 // (if it ever executes at all).
11974 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
11975 return false;
11976 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
11977 return false;
11978 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
11979 }
11980
11981 return false;
11982}
11983
11984bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
11985 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11986 const SCEV *FoundLHS, const SCEV *FoundRHS) {
11987 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
11988 return false;
11989
11990 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11991 if (!AddRecLHS)
11992 return false;
11993
11994 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
11995 if (!AddRecFoundLHS)
11996 return false;
11997
11998 // We'd like to let SCEV reason about control dependencies, so we constrain
11999 // both the inequalities to be about add recurrences on the same loop. This
12000 // way we can use isLoopEntryGuardedByCond later.
12001
12002 const Loop *L = AddRecFoundLHS->getLoop();
12003 if (L != AddRecLHS->getLoop())
12004 return false;
12005
12006 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12007 //
12008 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12009 // ... (2)
12010 //
12011 // Informal proof for (2), assuming (1) [*]:
12012 //
12013 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12014 //
12015 // Then
12016 //
12017 // FoundLHS s< FoundRHS s< INT_MIN - C
12018 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12019 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12020 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12021 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12022 // <=> FoundLHS + C s< FoundRHS + C
12023 //
12024 // [*]: (1) can be proved by ruling out overflow.
12025 //
12026 // [**]: This can be proved by analyzing all the four possibilities:
12027 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12028 // (A s>= 0, B s>= 0).
12029 //
12030 // Note:
12031 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12032 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12033 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12034 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12035 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12036 // C)".
12037
12038 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12039 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12040 if (!LDiff || !RDiff || *LDiff != *RDiff)
12041 return false;
12042
12043 if (LDiff->isMinValue())
12044 return true;
12045
12046 APInt FoundRHSLimit;
12047
12048 if (Pred == CmpInst::ICMP_ULT) {
12049 FoundRHSLimit = -(*RDiff);
12050 } else {
12051 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12052 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12053 }
12054
12055 // Try to prove (1) or (2), as needed.
12056 return isAvailableAtLoopEntry(FoundRHS, L) &&
12057 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12058 getConstant(FoundRHSLimit));
12059}
12060
12061bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
12062 const SCEV *LHS, const SCEV *RHS,
12063 const SCEV *FoundLHS,
12064 const SCEV *FoundRHS, unsigned Depth) {
12065 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12066
12067 auto ClearOnExit = make_scope_exit([&]() {
12068 if (LPhi) {
12069 bool Erased = PendingMerges.erase(LPhi);
12070 assert(Erased && "Failed to erase LPhi!");
12071 (void)Erased;
12072 }
12073 if (RPhi) {
12074 bool Erased = PendingMerges.erase(RPhi);
12075 assert(Erased && "Failed to erase RPhi!");
12076 (void)Erased;
12077 }
12078 });
12079
12080 // Find the respective Phis and check that they are not already pending.
12081 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12082 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12083 if (!PendingMerges.insert(Phi).second)
12084 return false;
12085 LPhi = Phi;
12086 }
12087 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12088 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12089 // If we detect a loop of Phi nodes being processed by this method, for
12090 // example:
12091 //
12092 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12093 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12094 //
12095 // we don't want to deal with a case that complex, so return the conservative
12096 // answer false.
12097 if (!PendingMerges.insert(Phi).second)
12098 return false;
12099 RPhi = Phi;
12100 }
12101
12102 // If none of LHS, RHS is a Phi, nothing to do here.
12103 if (!LPhi && !RPhi)
12104 return false;
12105
12106 // If there is a SCEVUnknown Phi we are interested in, make it left.
12107 if (!LPhi) {
12108 std::swap(LHS, RHS);
12109 std::swap(FoundLHS, FoundRHS);
12110 std::swap(LPhi, RPhi);
12111 Pred = ICmpInst::getSwappedPredicate(Pred);
12112 }
12113
12114 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12115 const BasicBlock *LBB = LPhi->getParent();
12116 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12117
12118 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12119 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12120 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12121 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12122 };
12123
12124 if (RPhi && RPhi->getParent() == LBB) {
12125 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12126 // If we compare two Phis from the same block, and for each entry block
12127 // the predicate is true for incoming values from this block, then the
12128 // predicate is also true for the Phis.
12129 for (const BasicBlock *IncBB : predecessors(LBB)) {
12130 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12131 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12132 if (!ProvedEasily(L, R))
12133 return false;
12134 }
12135 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12136 // Case two: RHS is also a Phi from the same basic block, and it is an
12137 // AddRec. It means that there is a loop which has both AddRec and Unknown
12138 // PHIs; for it we can compare incoming values of AddRec from above the loop
12139 // and latch with their respective incoming values of LPhi.
12140 // TODO: Generalize to handle loops with many inputs in a header.
12141 if (LPhi->getNumIncomingValues() != 2) return false;
12142
12143 auto *RLoop = RAR->getLoop();
12144 auto *Predecessor = RLoop->getLoopPredecessor();
12145 assert(Predecessor && "Loop with AddRec with no predecessor?");
12146 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12147 if (!ProvedEasily(L1, RAR->getStart()))
12148 return false;
12149 auto *Latch = RLoop->getLoopLatch();
12150 assert(Latch && "Loop with AddRec with no latch?");
12151 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12152 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12153 return false;
12154 } else {
12155 // In all other cases, go over the inputs of LHS and compare each of them to RHS;
12156 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12157 // At this point RHS is either a non-Phi, or it is a Phi from some block
12158 // different from LBB.
12159 for (const BasicBlock *IncBB : predecessors(LBB)) {
12160 // Check that RHS is available in this block.
12161 if (!dominates(RHS, IncBB))
12162 return false;
12163 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12164 // Make sure L does not refer to a value from a potentially previous
12165 // iteration of a loop.
12166 if (!properlyDominates(L, LBB))
12167 return false;
12168 if (!ProvedEasily(L, RHS))
12169 return false;
12170 }
12171 }
12172 return true;
12173}
12174
12175bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
12176 const SCEV *LHS,
12177 const SCEV *RHS,
12178 const SCEV *FoundLHS,
12179 const SCEV *FoundRHS) {
12180 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12181 // sure that we are dealing with the same LHS.
12182 if (RHS == FoundRHS) {
12183 std::swap(LHS, RHS);
12184 std::swap(FoundLHS, FoundRHS);
12185 Pred = ICmpInst::getSwappedPredicate(Pred);
12186 }
12187 if (LHS != FoundLHS)
12188 return false;
12189
12190 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12191 if (!SUFoundRHS)
12192 return false;
12193
12194 Value *Shiftee, *ShiftValue;
12195
12196 using namespace PatternMatch;
12197 if (match(SUFoundRHS->getValue(),
12198 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12199 auto *ShifteeS = getSCEV(Shiftee);
12200 // Prove one of the following:
12201 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12202 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12203 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12204 // ---> LHS <s RHS
12205 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12206 // ---> LHS <=s RHS
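// E.g. from "LHS u< (%x lshr 3)" together with "%x u<= RHS" we may conclude
// "LHS u< RHS", since a logical shift right never increases an unsigned
// value.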
12207 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12208 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12209 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12210 if (isKnownNonNegative(ShifteeS))
12211 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12212 }
12213
12214 return false;
12215}
12216
12217bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
12218 const SCEV *LHS, const SCEV *RHS,
12219 const SCEV *FoundLHS,
12220 const SCEV *FoundRHS,
12221 const Instruction *CtxI) {
12222 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
12223 return true;
12224
12225 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12226 return true;
12227
12228 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12229 return true;
12230
12231 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12232 CtxI))
12233 return true;
12234
12235 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12236 FoundLHS, FoundRHS);
12237}
12238
12239/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12240template <typename MinMaxExprType>
12241static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12242 const SCEV *Candidate) {
12243 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12244 if (!MinMaxExpr)
12245 return false;
12246
12247 return is_contained(MinMaxExpr->operands(), Candidate);
12248}
12249
12250static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12251 ICmpInst::Predicate Pred,
12252 const SCEV *LHS, const SCEV *RHS) {
12253 // If both sides are affine addrecs for the same loop, with equal
12254 // steps, and we know the recurrences don't wrap, then we only
12255 // need to check the predicate on the starting values.
12256
12257 if (!ICmpInst::isRelational(Pred))
12258 return false;
12259
12260 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12261 if (!LAR)
12262 return false;
12263 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12264 if (!RAR)
12265 return false;
12266 if (LAR->getLoop() != RAR->getLoop())
12267 return false;
12268 if (!LAR->isAffine() || !RAR->isAffine())
12269 return false;
12270
12271 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12272 return false;
12273
12274 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
12275 SCEV::FlagNSW : SCEV::FlagNUW;
12276 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12277 return false;
12278
12279 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12280}
12281
12282/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12283/// expression?
12284static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
12285 ICmpInst::Predicate Pred,
12286 const SCEV *LHS, const SCEV *RHS) {
12287 switch (Pred) {
12288 default:
12289 return false;
12290
12291 case ICmpInst::ICMP_SGE:
12292 std::swap(LHS, RHS);
12293 [[fallthrough]];
12294 case ICmpInst::ICMP_SLE:
12295 return
12296 // min(A, ...) <= A
12297 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12298 // A <= max(A, ...)
12299 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12300
12301 case ICmpInst::ICMP_UGE:
12302 std::swap(LHS, RHS);
12303 [[fallthrough]];
12304 case ICmpInst::ICMP_ULE:
12305 return
12306 // min(A, ...) <= A
12307 // FIXME: what about umin_seq?
12308 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12309 // A <= max(A, ...)
12310 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12311 }
12312
12313 llvm_unreachable("covered switch fell through?!");
12314}
12315
12316bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
12317 const SCEV *LHS, const SCEV *RHS,
12318 const SCEV *FoundLHS,
12319 const SCEV *FoundRHS,
12320 unsigned Depth) {
12321 assert(getTypeSizeInBits(LHS->getType()) ==
12322 getTypeSizeInBits(RHS->getType()) &&
12323 "LHS and RHS have different sizes?");
12324 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12325 getTypeSizeInBits(FoundRHS->getType()) &&
12326 "FoundLHS and FoundRHS have different sizes?");
12327 // We want to avoid hurting the compile time with analysis of too big trees.
12328 if (Depth > MaxSCEVOperationsImplicationDepth)
12329 return false;
12330
12331 // We only want to work with GT comparison so far.
12332 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12333 Pred = CmpInst::getSwappedPredicate(Pred);
12334 std::swap(LHS, RHS);
12335 std::swap(FoundLHS, FoundRHS);
12336 }
12337
12338 // For unsigned, try to reduce it to corresponding signed comparison.
12339 if (Pred == ICmpInst::ICMP_UGT)
12340 // We can replace unsigned predicate with its signed counterpart if all
12341 // involved values are non-negative.
12342 // TODO: We could have better support for unsigned.
12343 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12344 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12345 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12346 // use this fact to prove that LHS and RHS are non-negative.
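// E.g. if FoundLHS u> FoundRHS with both known non-negative, and that same
// fact also shows LHS s> -1 and RHS s> -1, then the unsigned query
// LHS u> RHS can be handled as the signed query LHS s> RHS below.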
12347 const SCEV *MinusOne = getMinusOne(LHS->getType());
12348 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12349 FoundRHS) &&
12350 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12351 FoundRHS))
12352 Pred = ICmpInst::ICMP_SGT;
12353 }
12354
12355 if (Pred != ICmpInst::ICMP_SGT)
12356 return false;
12357
12358 auto GetOpFromSExt = [&](const SCEV *S) {
12359 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12360 return Ext->getOperand();
12361 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12362 // the constant in some cases.
12363 return S;
12364 };
12365
12366 // Acquire values from extensions.
12367 auto *OrigLHS = LHS;
12368 auto *OrigFoundLHS = FoundLHS;
12369 LHS = GetOpFromSExt(LHS);
12370 FoundLHS = GetOpFromSExt(FoundLHS);
12371
12372 // Checks whether the SGT predicate can be proved trivially or using the found context.
12373 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12374 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12375 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12376 FoundRHS, Depth + 1);
12377 };
12378
12379 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12380 // We want to avoid creation of any new non-constant SCEV. Since we are
12381 // going to compare the operands to RHS, we should be certain that we don't
12382 // need any size extensions for this. So let's decline all cases when the
12383 // sizes of types of LHS and RHS do not match.
12384 // TODO: Maybe try to get RHS from sext to catch more cases?
12385 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12386 return false;
12387
12388 // Should not overflow.
12389 if (!LHSAddExpr->hasNoSignedWrap())
12390 return false;
12391
12392 auto *LL = LHSAddExpr->getOperand(0);
12393 auto *LR = LHSAddExpr->getOperand(1);
12394 auto *MinusOne = getMinusOne(RHS->getType());
12395
12396 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12397 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12398 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12399 };
12400 // Try to prove the following rule:
12401 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12402 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12403 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12404 return true;
12405 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12406 Value *LL, *LR;
12407 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12408
12409 using namespace llvm::PatternMatch;
12410
12411 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12412 // Rules for division.
12413 // We are going to perform some comparisons with Denominator and its
12414 // derivative expressions. In the general case, creating a SCEV for it may
12415 // lead to a complex analysis of the entire graph, and in particular it
12416 // can request trip count recalculation for the same loop. That recalculation
12417 // would be cached as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12418 // this, we only want to create SCEVs that are constants in this section.
12419 // So we bail if Denominator is not a constant.
12420 if (!isa<ConstantInt>(LR))
12421 return false;
12422
12423 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12424
12425 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12426 // then a SCEV for the numerator already exists and matches with FoundLHS.
12427 auto *Numerator = getExistingSCEV(LL);
12428 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12429 return false;
12430
12431 // Make sure that the numerator matches with FoundLHS and the denominator
12432 // is positive.
12433 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12434 return false;
12435
12436 auto *DTy = Denominator->getType();
12437 auto *FRHSTy = FoundRHS->getType();
12438 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12439 // One of types is a pointer and another one is not. We cannot extend
12440 // them properly to a wider type, so let us just reject this case.
12441 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12442 // to avoid this check.
12443 return false;
12444
12445 // Given that:
12446 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12447 auto *WTy = getWiderType(DTy, FRHSTy);
12448 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12449 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12450
12451 // Try to prove the following rule:
12452 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12453 // For example, if FoundLHS > 2, then FoundLHS is at least 3; dividing it
12454 // by a Denominator < 4 yields at least 1.
12455 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12456 if (isKnownNonPositive(RHS) &&
12457 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12458 return true;
12459
12460 // Try to prove the following rule:
12461 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12462 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
12463 // If we divide it by Denominator > 2, then:
12464 // 1. If FoundLHS is negative, then the result is 0.
12465 // 2. If FoundLHS is non-negative, then the result is non-negative.
12466 // Anyways, the result is non-negative.
12467 auto *MinusOne = getMinusOne(WTy);
12468 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12469 if (isKnownNegative(RHS) &&
12470 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12471 return true;
12472 }
12473 }
12474
12475 // If our expression contained SCEVUnknown Phis, and we split it down and now
12476 // need to prove something for them, try to prove the predicate for every
12477 // possible incoming values of those Phis.
12478 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12479 return true;
12480
12481 return false;
12482}
12483
12484bool ScalarEvolution::isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
12485 const SCEV *LHS, const SCEV *RHS) {
12486 // zext x u<= sext x, sext x s<= zext x
12487 switch (Pred) {
12488 case ICmpInst::ICMP_SGE:
12489 std::swap(LHS, RHS);
12490 [[fallthrough]];
12491 case ICmpInst::ICMP_SLE: {
12492 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
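// For example, for i8 x = -1: sext to i16 gives -1 while zext gives 255, and
// -1 <s 255; for a non-negative operand such as 5, both extensions give 5.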
12493 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
12494 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
12495 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12496 return true;
12497 break;
12498 }
12499 case ICmpInst::ICMP_UGE:
12500 std::swap(LHS, RHS);
12501 [[fallthrough]];
12502 case ICmpInst::ICMP_ULE: {
12503 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
12504 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
12505 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
12506 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12507 return true;
12508 break;
12509 }
12510 default:
12511 break;
12512 };
12513 return false;
12514}
12515
12516bool
12517ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
12518 const SCEV *LHS, const SCEV *RHS) {
12519 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12520 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12521 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12522 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12523 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12524}
12525
12526bool
12527ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
12528 const SCEV *LHS, const SCEV *RHS,
12529 const SCEV *FoundLHS,
12530 const SCEV *FoundRHS) {
12531 switch (Pred) {
12532 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
12533 case ICmpInst::ICMP_EQ:
12534 case ICmpInst::ICMP_NE:
12535 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12536 return true;
12537 break;
12538 case ICmpInst::ICMP_SLT:
12539 case ICmpInst::ICMP_SLE:
12540 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12541 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12542 return true;
12543 break;
12544 case ICmpInst::ICMP_SGT:
12545 case ICmpInst::ICMP_SGE:
12546 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12547 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12548 return true;
12549 break;
12550 case ICmpInst::ICMP_ULT:
12551 case ICmpInst::ICMP_ULE:
12552 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12553 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12554 return true;
12555 break;
12556 case ICmpInst::ICMP_UGT:
12557 case ICmpInst::ICMP_UGE:
12558 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12559 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12560 return true;
12561 break;
12562 }
12563
12564 // Maybe it can be proved via operations?
12565 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12566 return true;
12567
12568 return false;
12569}
12570
12571bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
12572 const SCEV *LHS,
12573 const SCEV *RHS,
12574 ICmpInst::Predicate FoundPred,
12575 const SCEV *FoundLHS,
12576 const SCEV *FoundRHS) {
12577 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12578 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12579 // reduce the compile time impact of this optimization.
12580 return false;
12581
12582 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12583 if (!Addend)
12584 return false;
12585
12586 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12587
12588 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12589 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12590 ConstantRange FoundLHSRange =
12591 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12592
12593 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12594 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12595
12596 // We can also compute the range of values for `LHS` that satisfy the
12597 // consequent, "`LHS` `Pred` `RHS`":
12598 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12599 // The antecedent implies the consequent if every value of `LHS` that
12600 // satisfies the antecedent also satisfies the consequent.
12601 return LHSRange.icmp(Pred, ConstRHS);
12602}
12603
12604bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12605 bool IsSigned) {
12606 assert(isKnownPositive(Stride) && "Positive stride expected!");
12607
12608 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12609 const SCEV *One = getOne(Stride->getType());
12610
12611 if (IsSigned) {
12612 APInt MaxRHS = getSignedRangeMax(RHS);
12613 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12614 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12615
12616 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12617 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12618 }
12619
12620 APInt MaxRHS = getUnsignedRangeMax(RHS);
12621 APInt MaxValue = APInt::getMaxValue(BitWidth);
12622 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12623
12624 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
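// For example, for i8: if MaxRHS = 250 and MaxStrideMinusOne = 10, then
// 255 - 10 = 245 <u 250, so an IV still below RHS (e.g. 249) could step past
// 255 and wrap.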
12625 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12626}
12627
12628bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12629 bool IsSigned) {
12630
12631 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12632 const SCEV *One = getOne(Stride->getType());
12633
12634 if (IsSigned) {
12635 APInt MinRHS = getSignedRangeMin(RHS);
12636 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12637 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12638
12639 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12640 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12641 }
12642
12643 APInt MinRHS = getUnsignedRangeMin(RHS);
12644 APInt MinValue = APInt::getMinValue(BitWidth);
12645 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12646
12647 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12648 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12649}
12650
12651const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
12652 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12653 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12654 // expression fixes the case of N=0.
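// For example, for N = 7 and D = 3: umin(7, 1) = 1 and floor((7 - 1) / 3) = 2,
// giving 3 = ceil(7 / 3); for N = 0 both terms are zero.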
12655 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12656 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12657 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12658}
12659
12660const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12661 const SCEV *Stride,
12662 const SCEV *End,
12663 unsigned BitWidth,
12664 bool IsSigned) {
12665 // The logic in this function assumes we can represent a positive stride.
12666 // If we can't, the backedge-taken count must be zero.
12667 if (IsSigned && BitWidth == 1)
12668 return getZero(Stride->getType());
12669
12670 // The code below has only been closely audited for negative strides in the
12671 // unsigned comparison case; it may be correct for signed comparison, but
12672 // that needs to be established.
12673 if (IsSigned && isKnownNegative(Stride))
12674 return getCouldNotCompute();
12675
12676 // Calculate the maximum backedge count based on the range of values
12677 // permitted by Start, End, and Stride.
12678 APInt MinStart =
12679 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12680
12681 APInt MinStride =
12682 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12683
12684 // We assume either the stride is positive, or the backedge-taken count
12685 // is zero. So force StrideForMaxBECount to be at least one.
12686 APInt One(BitWidth, 1);
12687 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12688 : APIntOps::umax(One, MinStride);
12689
12690 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12691 : APInt::getMaxValue(BitWidth);
12692 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12693
12694 // Although End can be a MAX expression we estimate MaxEnd considering only
12695 // the case End = RHS of the loop termination condition. This is safe because
12696 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12697 // taken count.
12698 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12699 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12700
12701 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
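// For example, MinStart = 0, MaxEnd = 200 and StrideForMaxBECount = 4 give
// ceil((200 - 0) / 4) = 50.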
12702 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12703 : APIntOps::umax(MaxEnd, MinStart);
12704
12705 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12706 getConstant(StrideForMaxBECount) /* Step */);
12707}
12708
12709ScalarEvolution::ExitLimit
12710ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12711 const Loop *L, bool IsSigned,
12712 bool ControlsOnlyExit, bool AllowPredicates) {
12713 SmallVector<const SCEVPredicate *, 4> Predicates;
12714
12715 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12716 bool PredicatedIV = false;
12717
12718 auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
12719 // Can we prove this loop *must* be UB if overflow of IV occurs?
12720 // Reasoning goes as follows:
12721 // * Suppose the IV did self wrap.
12722 // * If Stride evenly divides the iteration space, then once wrap
12723 // occurs, the loop must revisit the same values.
12724 // * We know that RHS is invariant, and that none of those values
12725 // caused this exit to be taken previously. Thus, this exit is
12726 // dynamically dead.
12727 // * If this is the sole exit, then a dead exit implies the loop
12728 // must be infinite if there are no abnormal exits.
12729 // * If the loop were infinite, then it must either not be mustprogress
12730 // or have side effects. Otherwise, it must be UB.
12731 // * It can't (by assumption) be UB, so we have contradicted our
12732 // premise and can conclude the IV did not in fact self-wrap.
12733 if (!isLoopInvariant(RHS, L))
12734 return false;
12735
12736 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
12737 if (!StrideC || !StrideC->getAPInt().isPowerOf2())
12738 return false;
12739
12740 if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
12741 return false;
12742
12743 return loopIsFiniteByAssumption(L);
12744 };
12745
12746 if (!IV) {
12747 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12748 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12749 if (AR && AR->getLoop() == L && AR->isAffine()) {
12750 auto canProveNUW = [&]() {
12751 // We can use the comparison to infer no-wrap flags only if it fully
12752 // controls the loop exit.
12753 if (!ControlsOnlyExit)
12754 return false;
12755
12756 if (!isLoopInvariant(RHS, L))
12757 return false;
12758
12759 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12760 // We need the sequence defined by AR to strictly increase in the
12761 // unsigned integer domain for the logic below to hold.
12762 return false;
12763
12764 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12765 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12766 // If RHS <=u Limit, then there must exist a value V in the sequence
12767 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12768 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12769 // overflow occurs. This limit also implies that a signed comparison
12770 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12771 // the high bits on both sides must be zero.
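// For example, with an i8 IV and StrideMax = 3, Limit is 255 - 2 = 253: a
// wrapping step must start from a value >= 253, but with RHS <=u 253 the
// exit has already been taken at such a value.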
12772 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12773 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12774 Limit = Limit.zext(OuterBitWidth);
12775 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12776 };
12777 auto Flags = AR->getNoWrapFlags();
12778 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12779 Flags = setFlags(Flags, SCEV::FlagNUW);
12780
12781 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12782 if (AR->hasNoUnsignedWrap()) {
12783 // Emulate what getZeroExtendExpr would have done during construction
12784 // if we'd been able to infer the fact just above at that time.
12785 const SCEV *Step = AR->getStepRecurrence(*this);
12786 Type *Ty = ZExt->getType();
12787 auto *S = getAddRecExpr(
12788 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12789 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12790 IV = dyn_cast<SCEVAddRecExpr>(S);
12791 }
12792 }
12793 }
12794 }
12795
12796
12797 if (!IV && AllowPredicates) {
12798 // Try to make this an AddRec using runtime tests, in the first X
12799 // iterations of this loop, where X is the SCEV expression found by the
12800 // algorithm below.
12801 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
12802 PredicatedIV = true;
12803 }
12804
12805 // Avoid weird loops
12806 if (!IV || IV->getLoop() != L || !IV->isAffine())
12807 return getCouldNotCompute();
12808
12809 // A precondition of this method is that the condition being analyzed
12810 // reaches an exiting branch which dominates the latch. Given that, we can
12811 // assume that an increment which violates the nowrap specification and
12812 // produces poison must cause undefined behavior when the resulting poison
12813 // value is branched upon and thus we can conclude that the backedge is
12814 // taken no more often than would be required to produce that poison value.
12815 // Note that a well defined loop can exit on the iteration which violates
12816 // the nowrap specification if there is another exit (either explicit or
12817 // implicit/exceptional) which causes the loop to execute before the
12818 // exiting instruction we're analyzing would trigger UB.
12819 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
12820 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
12821 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
12822
12823 const SCEV *Stride = IV->getStepRecurrence(*this);
12824
12825 bool PositiveStride = isKnownPositive(Stride);
12826
12827 // Avoid negative or zero stride values.
12828 if (!PositiveStride) {
12829 // We can compute the correct backedge taken count for loops with unknown
12830 // strides if we can prove that the loop is not an infinite loop with side
12831 // effects. Here's the loop structure we are trying to handle -
12832 //
12833 // i = start
12834 // do {
12835 // A[i] = i;
12836 // i += s;
12837 // } while (i < end);
12838 //
12839 // The backedge taken count for such loops is evaluated as -
12840 // (max(end, start + stride) - start - 1) /u stride
12841 //
12842 // The additional preconditions that we need to check to prove correctness
12843 // of the above formula are as follows -
12844 //
12845 // a) IV is either nuw or nsw depending upon signedness (indicated by the
12846 // NoWrap flag).
12847 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
12848 // no side effects within the loop)
12849 // c) loop has a single static exit (with no abnormal exits)
12850 //
12851 // Precondition a) implies that if the stride is negative, this is a single
12852 // trip loop. The backedge taken count formula reduces to zero in this case.
12853 //
12854 // Preconditions b) and c) combine to imply that if RHS is invariant in L,
12855 // then a zero stride means the backedge can't be taken without executing
12856 // undefined behavior.
12857 //
12858 // The positive stride case is the same as isKnownPositive(Stride) returning
12859 // true (original behavior of the function).
12860 //
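// For example, start = 0, stride = 3, end = 10:
// (max(10, 0 + 3) - 0 - 1) /u 3 = 9 /u 3 = 3 backedges (the body runs for
// i = 0, 3, 6 and 9).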
12861 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
12862 !loopHasNoAbnormalExits(L))
12863 return getCouldNotCompute();
12864
12865 if (!isKnownNonZero(Stride)) {
12866 // If we have a step of zero, and RHS isn't invariant in L, we don't know
12867 // if it might eventually be greater than start and if so, on which
12868 // iteration. We can't even produce a useful upper bound.
12869 if (!isLoopInvariant(RHS, L))
12870 return getCouldNotCompute();
12871
12872 // We allow a potentially zero stride, but we need to divide by stride
12873 // below. Since the loop can't be infinite and this check must control
12874 // the sole exit, we can infer the exit must be taken on the first
12875 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
12876 // we know the numerator in the divides below must be zero, so we can
12877 // pick an arbitrary non-zero value for the denominator (e.g. stride)
12878 // and produce the right result.
12879 // FIXME: Handle the case where Stride is poison?
12880 auto wouldZeroStrideBeUB = [&]() {
12881 // Proof by contradiction. Suppose the stride were zero. If we can
12882 // prove that the backedge *is* taken on the first iteration, then since
12883 // we know this condition controls the sole exit, we must have an
12884 // infinite loop. We can't have a (well defined) infinite loop per
12885 // check just above.
12886 // Note: The (Start - Stride) term is used to get the start' term from
12887 // (start' + stride,+,stride). Remember that we only care about the
12888 // result of this expression when stride == 0 at runtime.
12889 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
12890 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
12891 };
12892 if (!wouldZeroStrideBeUB()) {
12893 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
12894 }
12895 }
12896 } else if (!Stride->isOne() && !NoWrap) {
12897 auto isUBOnWrap = [&]() {
12898 // From no-self-wrap, we then need to prove no-(un)signed-wrap. This
12899 // follows from the fact that every (un)signed-wrapped, but not
12900 // self-wrapped, value must be less than the last value before the
12901 // (un)signed wrap. Since that last value did not cause an exit,
12902 // neither will any smaller one.
12903 return canAssumeNoSelfWrap(IV);
12904 };
12905
12906 // Avoid proven overflow cases: this will ensure that the backedge taken
12907 // count will not generate any unsigned overflow. Relaxed no-overflow
12908 // conditions exploit NoWrapFlags, allowing us to optimize in the presence
12909 // of undefined behavior, as in the C language.
12910 if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
12911 return getCouldNotCompute();
12912 }
12913
12914 // On all paths just preceding, we established the following invariant:
12915 // IV can be assumed not to overflow up to and including the exiting
12916 // iteration. We proved this in one of two ways:
12917 // 1) We can show overflow doesn't occur before the exiting iteration
12918 // 1a) via the canIVOverflowOnLT check, or 1b) because the step is one
12919 // 2) We can show that if overflow occurs, the loop must execute UB
12920 // before any possible exit.
12921 // Note that we have not yet proved RHS invariant (in general).
12922
12923 const SCEV *Start = IV->getStart();
12924
12925 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
12926 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
12927 // Use integer-typed versions for actual computation; we can't subtract
12928 // pointers in general.
12929 const SCEV *OrigStart = Start;
12930 const SCEV *OrigRHS = RHS;
12931 if (Start->getType()->isPointerTy()) {
12932 Start = getLosslessPtrToIntExpr(Start);
12933 if (isa<SCEVCouldNotCompute>(Start))
12934 return Start;
12935 }
12936 if (RHS->getType()->isPointerTy()) {
12937 RHS = getLosslessPtrToIntExpr(RHS);
12938 if (isa<SCEVCouldNotCompute>(RHS))
12939 return RHS;
12940 }
12941
12942 // When the RHS is not invariant, we do not know the end bound of the loop and
12943 // cannot calculate the ExactBECount needed by ExitLimit. However, we can
12944 // calculate the MaxBECount, given the start, stride and max value for the end
12945 // bound of the loop (RHS), and the fact that IV does not overflow (which is
12946 // checked above).
12947 if (!isLoopInvariant(RHS, L)) {
12948 const SCEV *MaxBECount = computeMaxBECountForLT(
12949 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
12950 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
12951 MaxBECount, false /*MaxOrZero*/, Predicates);
12952 }
12953
12954 // We use the expression (max(End,Start)-Start)/Stride to describe the
12955 // backedge count, as if the backedge is taken at least once max(End,Start)
12956 // is End and so the result is as above, and if not max(End,Start) is Start
12957 // so we get a backedge count of zero.
12958 const SCEV *BECount = nullptr;
12959 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
12960 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
12961 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
12962 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
12963 // Can we prove max(RHS,Start) > Start - Stride?
12964 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
12965 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
12966 // In this case, we can use a refined formula for computing backedge taken
12967 // count. The general formula remains:
12968 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
12969 // We want to use the alternate formula:
12970 // "((End - 1) - (Start - Stride)) /u Stride"
12971 // Let's do a quick case analysis to show these are equivalent under
12972 // our precondition that max(RHS,Start) > Start - Stride.
12973 // * For RHS <= Start, the backedge-taken count must be zero.
12974 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
12975 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
12976 // "(Stride - 1) /u Stride", which is indeed zero for all non-zero values
12977 // of Stride. For a zero stride, we've used umax(Stride, 1) above, reducing
12978 // this to the stride-of-one case.
12979 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
12980 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
12981 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
12982 // "((RHS - (Start - Stride) - 1) /u Stride".
12983 // Our preconditions trivially imply no overflow in that form.
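// For example, Start = 5, Stride = 3, RHS = 14:
// ((14 - 1) - (5 - 3)) /u 3 = 11 /u 3 = 3, matching ceil((14 - 5) / 3) = 3.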
12984 const SCEV *MinusOne = getMinusOne(Stride->getType());
12985 const SCEV *Numerator =
12986 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
12987 BECount = getUDivExpr(Numerator, Stride);
12988 }
12989
12990 const SCEV *BECountIfBackedgeTaken = nullptr;
12991 if (!BECount) {
12992 auto canProveRHSGreaterThanEqualStart = [&]() {
12993 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
12994 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
12995 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
12996
12997 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
12998 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
12999 return true;
13000
13001 // (RHS > Start - 1) implies RHS >= Start.
13002 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13003 // "Start - 1" doesn't overflow.
13004 // * For signed comparison, if Start - 1 does overflow, it's equal
13005 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13006 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13007 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13008 //
13009 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13010 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13011 auto *StartMinusOne = getAddExpr(OrigStart,
13012 getMinusOne(OrigStart->getType()));
13013 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13014 };
13015
13016 // If we know that RHS >= Start in the context of loop, then we know that
13017 // max(RHS, Start) = RHS at this point.
13018 const SCEV *End;
13019 if (canProveRHSGreaterThanEqualStart()) {
13020 End = RHS;
13021 } else {
13022 // If RHS < Start, the backedge will be taken zero times. So in
13023 // general, we can write the backedge-taken count as:
13024 //
13025 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13026 //
13027 // We convert it to the following to make it more convenient for SCEV:
13028 //
13029 // ceil(max(RHS, Start) - Start) / Stride
13030 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13031
13032 // See what would happen if we assume the backedge is taken. This is
13033 // used to compute MaxBECount.
13034 BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13035 }
13036
13037 // At this point, we know:
13038 //
13039 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13040 // 2. The index variable doesn't overflow.
13041 //
13042 // Therefore, we know N exists such that
13043 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13044 // doesn't overflow.
13045 //
13046 // Using this information, try to prove whether the addition in
13047 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13048 const SCEV *One = getOne(Stride->getType());
13049 bool MayAddOverflow = [&] {
13050 if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
13051 if (StrideC->getAPInt().isPowerOf2()) {
13052 // Suppose Stride is a power of two, and Start/End are unsigned
13053 // integers. Let UMAX be the largest representable unsigned
13054 // integer.
13055 //
13056 // By the preconditions of this function, we know
13057 // "(Start + Stride * N) >= End", and this doesn't overflow.
13058 // As a formula:
13059 //
13060 // End <= (Start + Stride * N) <= UMAX
13061 //
13062 // Subtracting Start from all the terms:
13063 //
13064 // End - Start <= Stride * N <= UMAX - Start
13065 //
13066 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13067 //
13068 // End - Start <= Stride * N <= UMAX
13069 //
13070 // Stride * N is a multiple of Stride. Therefore,
13071 //
13072 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13073 //
13074 // Since Stride is a power of two, UMAX + 1 is divisible by Stride.
13075 // Therefore, UMAX mod Stride == Stride - 1. So we can write:
13076 //
13077 // End - Start <= Stride * N <= UMAX - (Stride - 1)
13078 //
13079 // Dropping the middle term:
13080 //
13081 // End - Start <= UMAX - (Stride - 1)
13082 //
13083 // Adding Stride - 1 to both sides:
13084 //
13085 // (End - Start) + (Stride - 1) <= UMAX
13086 //
13087 // In other words, the addition doesn't have unsigned overflow.
13088 //
13089 // A similar proof works if we treat Start/End as signed values.
13090 // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
13091 // use signed max instead of unsigned max. Note that we're trying
13092 // to prove a lack of unsigned overflow in either case.
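// Concretely, for i8 with Stride = 8: UMAX mod Stride = 7, so
// End - Start <= 255 - 7 = 248 and (End - Start) + 7 <= 255.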
13093 return false;
13094 }
13095 }
13096 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13097 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
13098 // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
13099 // If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
13100 //
13101 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
13102 return false;
13103 }
13104 return true;
13105 }();
13106
13107 const SCEV *Delta = getMinusSCEV(End, Start);
13108 if (!MayAddOverflow) {
13109 // floor((D + (S - 1)) / S)
13110 // We prefer this formulation if it's legal because it's fewer operations.
13111 BECount =
13112 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13113 } else {
13114 BECount = getUDivCeilSCEV(Delta, Stride);
13115 }
13116 }
13117
13118 const SCEV *ConstantMaxBECount;
13119 bool MaxOrZero = false;
13120 if (isa<SCEVConstant>(BECount)) {
13121 ConstantMaxBECount = BECount;
13122 } else if (BECountIfBackedgeTaken &&
13123 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13124 // If we know exactly how many times the backedge will be taken if it's
13125 // taken at least once, then the backedge count will either be that or
13126 // zero.
13127 ConstantMaxBECount = BECountIfBackedgeTaken;
13128 MaxOrZero = true;
13129 } else {
13130 ConstantMaxBECount = computeMaxBECountForLT(
13131 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13132 }
13133
13134 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13135 !isa<SCEVCouldNotCompute>(BECount))
13136 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13137
13138 const SCEV *SymbolicMaxBECount =
13139 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13140 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13141 Predicates);
13142}
13143
13144ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13145 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13146 bool ControlsOnlyExit, bool AllowPredicates) {
13147 SmallVector<const SCEVPredicate *, 4> Predicates;
13148 // We handle only IV > Invariant
13149 if (!isLoopInvariant(RHS, L))
13150 return getCouldNotCompute();
13151
13152 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13153 if (!IV && AllowPredicates)
13154 // Try to make this an AddRec using runtime tests, in the first X
13155 // iterations of this loop, where X is the SCEV expression found by the
13156 // algorithm below.
13157 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13158
13159 // Avoid weird loops
13160 if (!IV || IV->getLoop() != L || !IV->isAffine())
13161 return getCouldNotCompute();
13162
13163 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13164 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13165 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13166
13167 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13168
13169 // Avoid negative or zero stride values
13170 if (!isKnownPositive(Stride))
13171 return getCouldNotCompute();
13172
13173 // Avoid proven overflow cases: this will ensure that the backedge taken count
13174 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13175 // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
13176 // behavior, as in the C language.
13177 if (!Stride->isOne() && !NoWrap)
13178 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13179 return getCouldNotCompute();
13180
13181 const SCEV *Start = IV->getStart();
13182 const SCEV *End = RHS;
13183 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13184 // If we know that Start >= RHS in the context of loop, then we know that
13185 // min(RHS, Start) = RHS at this point.
13186 if (isLoopEntryGuardedByCond(
13187 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13188 End = RHS;
13189 else
13190 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13191 }
13192
13193 if (Start->getType()->isPointerTy()) {
13194 Start = getLosslessPtrToIntExpr(Start);
13195 if (isa<SCEVCouldNotCompute>(Start))
13196 return Start;
13197 }
13198 if (End->getType()->isPointerTy()) {
13199 End = getLosslessPtrToIntExpr(End);
13200 if (isa<SCEVCouldNotCompute>(End))
13201 return End;
13202 }
13203
13204 // Compute ((Start - End) + (Stride - 1)) / Stride.
13205 // FIXME: This can overflow. Holding off on fixing this for now;
13206 // howManyGreaterThans will hopefully be gone soon.
13207 const SCEV *One = getOne(Stride->getType());
13208 const SCEV *BECount = getUDivExpr(
13209 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
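// For example, Start = 20, End = 5, Stride = 4:
// ((20 - 5) + 3) /u 4 = 18 /u 4 = 4 = ceil(15 / 4).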
13210
13211 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13212 : getUnsignedRangeMax(Start);
13213
13214 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13215 : getUnsignedRangeMin(Stride);
13216
13217 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13218 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13219 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13220
13221 // Although End can be a MIN expression we estimate MinEnd considering only
13222 // the case End = RHS. This is safe because in the other case (Start - End)
13223 // is zero, leading to a zero maximum backedge taken count.
13224 APInt MinEnd =
13225 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13226 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13227
13228 const SCEV *ConstantMaxBECount =
13229 isa<SCEVConstant>(BECount)
13230 ? BECount
13231 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13232 getConstant(MinStride));
13233
13234 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13235 ConstantMaxBECount = BECount;
13236 const SCEV *SymbolicMaxBECount =
13237 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13238
13239 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13240 Predicates);
13241}
13242
13243const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13244 ScalarEvolution &SE) const {
13245 if (Range.isFullSet()) // Infinite loop.
13246 return SE.getCouldNotCompute();
13247
13248 // If the start is a non-zero constant, shift the range to simplify things.
13249 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13250 if (!SC->getValue()->isZero()) {
13251 SmallVector<const SCEV *, 4> Operands(operands());
13252 Operands[0] = SE.getZero(SC->getType());
13253 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13254 getNoWrapFlags(FlagNW));
13255 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13256 return ShiftedAddRec->getNumIterationsInRange(
13257 Range.subtract(SC->getAPInt()), SE);
13258 // This is strange and shouldn't happen.
13259 return SE.getCouldNotCompute();
13260 }
13261
13262 // The only time we can solve this is when we have all constant indices.
13263 // Otherwise, we cannot determine the overflow conditions.
13264 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13265 return SE.getCouldNotCompute();
13266
13267 // Okay at this point we know that all elements of the chrec are constants and
13268 // that the start element is zero.
13269
13270 // First check to see if the range contains zero. If not, the first
13271 // iteration exits.
13272 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13273 if (!Range.contains(APInt(BitWidth, 0)))
13274 return SE.getZero(getType());
13275
13276 if (isAffine()) {
13277 // If this is an affine expression then we have this situation:
13278 // Solve {0,+,A} in Range === Ax in Range
13279
13280 // We know that zero is in the range. If A is positive then we know that
13281 // the upper value of the range must be the first possible exit value.
13282 // If A is negative then the lower of the range is the last possible loop
13283 // value. Also note that we already checked for a full range.
13284 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13285 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13286
13287 // The exit value should be (End+A)/A.
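// For example, for Range = [0, 10) and A = 3: End = 9 and
// ExitVal = (9 + 3) / 3 = 4; the values 0, 3, 6, 9 lie in the range and
// iteration 4 (value 12) falls out of it.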
13288 APInt ExitVal = (End + A).udiv(A);
13289 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13290
13291 // Evaluate at the exit value. If we really did fall out of the valid
13292 // range, then we computed our trip count, otherwise wrap around or other
13293 // things must have happened.
13294 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13295 if (Range.contains(Val->getValue()))
13296 return SE.getCouldNotCompute(); // Something strange happened
13297
13298 // Ensure that the previous value is in the range.
13299 assert(Range.contains(
13300 EvaluateConstantChrecAtConstant(this,
13301 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13302 "Linear scev computation is off in a bad way!");
13303 return SE.getConstant(ExitValue);
13304 }
13305
13306 if (isQuadratic()) {
13307 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13308 return SE.getConstant(*S);
13309 }
13310
13311 return SE.getCouldNotCompute();
13312}
13313
13314const SCEVAddRecExpr *
13315SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13316 assert(getNumOperands() > 1 && "AddRec with zero step?");
13317 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13318 // but in this case we cannot guarantee that the value returned will be an
13319 // AddRec because SCEV does not have a fixed point where it stops
13320 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13321 // may happen if we reach arithmetic depth limit while simplifying. So we
13322 // construct the returned value explicitly.
13323 SmallVector<const SCEV *, 3> Ops;
13324 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13325 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13326 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13327 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13328 // We know that the last operand is not a constant zero (otherwise it would
13329 // have been popped out earlier). This guarantees us that if the result has
13330 // the same last operand, then it will also not be popped out, meaning that
13331 // the returned value will be an AddRec.
13332 const SCEV *Last = getOperand(getNumOperands() - 1);
13333 assert(!Last->isZero() && "Recurrency with zero step?");
13334 Ops.push_back(Last);
13335 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13336 SCEV::FlagAnyWrap));
13337}
13338
13339// Return true when S contains at least an undef value.
13340bool ScalarEvolution::containsUndefs(const SCEV *S) const {
13341 return SCEVExprContains(S, [](const SCEV *S) {
13342 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13343 return isa<UndefValue>(SU->getValue());
13344 return false;
13345 });
13346}
13347
13348// Return true when S contains a value that is a nullptr.
13349bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
13350 return SCEVExprContains(S, [](const SCEV *S) {
13351 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13352 return SU->getValue() == nullptr;
13353 return false;
13354 });
13355}
13356
13357/// Return the size of an element read or written by Inst.
13358const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13359 Type *Ty;
13360 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13361 Ty = Store->getValueOperand()->getType();
13362 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13363 Ty = Load->getType();
13364 else
13365 return nullptr;
13366
13367 Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
13368 return getSizeOfExpr(ETy, Ty);
13369}
13370
13371//===----------------------------------------------------------------------===//
13372// SCEVCallbackVH Class Implementation
13373//===----------------------------------------------------------------------===//
13374
13375void ScalarEvolution::SCEVCallbackVH::deleted() {
13376 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13377 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13378 SE->ConstantEvolutionLoopExitValue.erase(PN);
13379 SE->eraseValueFromMap(getValPtr());
13380 // this now dangles!
13381}
13382
13383void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13384 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13385
13386 // Forget all the expressions associated with users of the old value,
13387 // so that future queries will recompute the expressions using the new
13388 // value.
13389 SE->forgetValue(getValPtr());
13390 // this now dangles!
13391}
13392
13393ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13394 : CallbackVH(V), SE(se) {}
13395
13396//===----------------------------------------------------------------------===//
13397// ScalarEvolution Class Implementation
13398//===----------------------------------------------------------------------===//
13399
13400ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13401 AssumptionCache &AC, DominatorTree &DT,
13402 LoopInfo &LI)
13403 : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
13404 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13405 LoopDispositions(64), BlockDispositions(64) {
13406 // To use guards for proving predicates, we need to scan every instruction in
13407 // relevant basic blocks, and not just terminators. Doing this is a waste of
13408 // time if the IR does not actually contain any calls to
13409 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13410 //
13411 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13412 // to _add_ guards to the module when there weren't any before, and wants
13413 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13414 // efficient in lieu of being smart in that rather obscure case.
13415
13416 auto *GuardDecl = F.getParent()->getFunction(
13417 Intrinsic::getName(Intrinsic::experimental_guard));
13418 HasGuards = GuardDecl && !GuardDecl->use_empty();
13419}
13420
13421ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13422 : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
13423 LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13424 ValueExprMap(std::move(Arg.ValueExprMap)),
13425 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13426 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13427 PendingMerges(std::move(Arg.PendingMerges)),
13428 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13429 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13430 PredicatedBackedgeTakenCounts(
13431 std::move(Arg.PredicatedBackedgeTakenCounts)),
13432 BECountUsers(std::move(Arg.BECountUsers)),
13433 ConstantEvolutionLoopExitValue(
13434 std::move(Arg.ConstantEvolutionLoopExitValue)),
13435 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13436 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13437 LoopDispositions(std::move(Arg.LoopDispositions)),
13438 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13439 BlockDispositions(std::move(Arg.BlockDispositions)),
13440 SCEVUsers(std::move(Arg.SCEVUsers)),
13441 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13442 SignedRanges(std::move(Arg.SignedRanges)),
13443 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13444 UniquePreds(std::move(Arg.UniquePreds)),
13445 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13446 LoopUsers(std::move(Arg.LoopUsers)),
13447 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13448 FirstUnknown(Arg.FirstUnknown) {
13449 Arg.FirstUnknown = nullptr;
13450}
13451
13452ScalarEvolution::~ScalarEvolution() {
13453 // Iterate through all the SCEVUnknown instances and call their
13454 // destructors, so that they release their references to their values.
13455 for (SCEVUnknown *U = FirstUnknown; U;) {
13456 SCEVUnknown *Tmp = U;
13457 U = U->Next;
13458 Tmp->~SCEVUnknown();
13459 }
13460 FirstUnknown = nullptr;
13461
13462 ExprValueMap.clear();
13463 ValueExprMap.clear();
13464 HasRecMap.clear();
13465 BackedgeTakenCounts.clear();
13466 PredicatedBackedgeTakenCounts.clear();
13467
13468 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13469 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13470 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13471 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13472 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13473}
13474
13475bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13476 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13477}
13478
13479/// When printing a top-level SCEV for trip counts, it's helpful to include
13480/// a type for constants which are otherwise hard to disambiguate.
13481static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13482 if (isa<SCEVConstant>(S))
13483 OS << *S->getType() << " ";
13484 OS << *S;
13485}
13486
13487static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13488 const Loop *L) {
13489 // Print all inner loops first
13490 for (Loop *I : *L)
13491 PrintLoopInfo(OS, SE, I);
13492
13493 OS << "Loop ";
13494 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13495 OS << ": ";
13496
13497 SmallVector<BasicBlock *, 8> ExitingBlocks;
13498 L->getExitingBlocks(ExitingBlocks);
13499 if (ExitingBlocks.size() != 1)
13500 OS << "<multiple exits> ";
13501
13502 auto *BTC = SE->getBackedgeTakenCount(L);
13503 if (!isa<SCEVCouldNotCompute>(BTC)) {
13504 OS << "backedge-taken count is ";
13505 PrintSCEVWithTypeHint(OS, BTC);
13506 } else
13507 OS << "Unpredictable backedge-taken count.";
13508 OS << "\n";
13509
13510 if (ExitingBlocks.size() > 1)
13511 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13512 OS << " exit count for " << ExitingBlock->getName() << ": ";
13513 PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
13514 OS << "\n";
13515 }
13516
13517 OS << "Loop ";
13518 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13519 OS << ": ";
13520
13521 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13522 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13523 OS << "constant max backedge-taken count is ";
13524 PrintSCEVWithTypeHint(OS, ConstantBTC);
13525 if (SE->isBackedgeTakenCountMaxOrZero(L))
13526 OS << ", actual taken count either this or zero.";
13527 } else {
13528 OS << "Unpredictable constant max backedge-taken count. ";
13529 }
13530
13531 OS << "\n"
13532 "Loop ";
13533 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13534 OS << ": ";
13535
13536 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13537 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13538 OS << "symbolic max backedge-taken count is ";
13539 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13540 if (SE->isBackedgeTakenCountMaxOrZero(L))
13541 OS << ", actual taken count either this or zero.";
13542 } else {
13543 OS << "Unpredictable symbolic max backedge-taken count. ";
13544 }
13545 OS << "\n";
13546
13547 if (ExitingBlocks.size() > 1)
13548 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13549 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13550 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13551 ScalarEvolution::SymbolicMaximum);
13552 PrintSCEVWithTypeHint(OS, ExitBTC);
13553 OS << "\n";
13554 }
13555
13556 SmallVector<const SCEVPredicate *, 4> Preds;
13557 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13558 if (PBT != BTC || !Preds.empty()) {
13559 OS << "Loop ";
13560 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13561 OS << ": ";
13562 if (!isa<SCEVCouldNotCompute>(PBT)) {
13563 OS << "Predicated backedge-taken count is ";
13564 PrintSCEVWithTypeHint(OS, PBT);
13565 } else
13566 OS << "Unpredictable predicated backedge-taken count.";
13567 OS << "\n";
13568 OS << " Predicates:\n";
13569 for (const auto *P : Preds)
13570 P->print(OS, 4);
13571 }
13572
13573 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13574 OS << "Loop ";
13575 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13576 OS << ": ";
13577 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13578 }
13579}
13580
13581namespace llvm {
13582raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13583 switch (LD) {
13584 case ScalarEvolution::LoopVariant:
13585 OS << "Variant";
13586 break;
13587 case ScalarEvolution::LoopInvariant:
13588 OS << "Invariant";
13589 break;
13590 case ScalarEvolution::LoopComputable:
13591 OS << "Computable";
13592 break;
13593 }
13594 return OS;
13595}
13596
13597raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13598 switch (BD) {
13599 case ScalarEvolution::DoesNotDominateBlock:
13600 OS << "DoesNotDominate";
13601 break;
13602 case ScalarEvolution::DominatesBlock:
13603 OS << "Dominates";
13604 break;
13605 case ScalarEvolution::ProperlyDominatesBlock:
13606 OS << "ProperlyDominates";
13607 break;
13608 }
13609 return OS;
13610}
13611}
13612
13613void ScalarEvolution::print(raw_ostream &OS) const {
13614 // ScalarEvolution's implementation of the print method is to print
13615 // out SCEV values of all instructions that are interesting. Doing
13616 // this potentially causes it to create new SCEV objects though,
13617 // which technically conflicts with the const qualifier. This isn't
13618 // observable from outside the class though, so casting away the
13619 // const isn't dangerous.
13620 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13621
13622 if (ClassifyExpressions) {
13623 OS << "Classifying expressions for: ";
13624 F.printAsOperand(OS, /*PrintType=*/false);
13625 OS << "\n";
13626 for (Instruction &I : instructions(F))
13627 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13628 OS << I << '\n';
13629 OS << " --> ";
13630 const SCEV *SV = SE.getSCEV(&I);
13631 SV->print(OS);
13632 if (!isa<SCEVCouldNotCompute>(SV)) {
13633 OS << " U: ";
13634 SE.getUnsignedRange(SV).print(OS);
13635 OS << " S: ";
13636 SE.getSignedRange(SV).print(OS);
13637 }
13638
13639 const Loop *L = LI.getLoopFor(I.getParent());
13640
13641 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13642 if (AtUse != SV) {
13643 OS << " --> ";
13644 AtUse->print(OS);
13645 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13646 OS << " U: ";
13647 SE.getUnsignedRange(AtUse).print(OS);
13648 OS << " S: ";
13649 SE.getSignedRange(AtUse).print(OS);
13650 }
13651 }
13652
13653 if (L) {
13654 OS << "\t\t" "Exits: ";
13655 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13656 if (!SE.isLoopInvariant(ExitValue, L)) {
13657 OS << "<<Unknown>>";
13658 } else {
13659 OS << *ExitValue;
13660 }
13661
13662 bool First = true;
13663 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13664 if (First) {
13665 OS << "\t\t" "LoopDispositions: { ";
13666 First = false;
13667 } else {
13668 OS << ", ";
13669 }
13670
13671 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13672 OS << ": " << SE.getLoopDisposition(SV, Iter);
13673 }
13674
13675 for (const auto *InnerL : depth_first(L)) {
13676 if (InnerL == L)
13677 continue;
13678 if (First) {
13679 OS << "\t\t" "LoopDispositions: { ";
13680 First = false;
13681 } else {
13682 OS << ", ";
13683 }
13684
13685 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13686 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13687 }
13688
13689 OS << " }";
13690 }
13691
13692 OS << "\n";
13693 }
13694 }
13695
13696 OS << "Determining loop execution counts for: ";
13697 F.printAsOperand(OS, /*PrintType=*/false);
13698 OS << "\n";
13699 for (Loop *I : LI)
13700 PrintLoopInfo(OS, &SE, I);
13701}
13702
13703ScalarEvolution::LoopDisposition
13704ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
13705 auto &Values = LoopDispositions[S];
13706 for (auto &V : Values) {
13707 if (V.getPointer() == L)
13708 return V.getInt();
13709 }
13710 Values.emplace_back(L, LoopVariant);
13711 LoopDisposition D = computeLoopDisposition(S, L);
13712 auto &Values2 = LoopDispositions[S];
13713 for (auto &V : llvm::reverse(Values2)) {
13714 if (V.getPointer() == L) {
13715 V.setInt(D);
13716 break;
13717 }
13718 }
13719 return D;
13720}
13721
13722ScalarEvolution::LoopDisposition
13723ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
13724 switch (S->getSCEVType()) {
13725 case scConstant:
13726 case scVScale:
13727 return LoopInvariant;
13728 case scAddRecExpr: {
13729 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13730
13731 // If L is the addrec's loop, it's computable.
13732 if (AR->getLoop() == L)
13733 return LoopComputable;
13734
13735 // Add recurrences are never invariant in the function-body (null loop).
13736 if (!L)
13737 return LoopVariant;
13738
13739 // Everything that is not defined at loop entry is variant.
13740 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
13741 return LoopVariant;
13742 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
13743 " dominate the contained loop's header?");
13744
13745 // This recurrence is invariant w.r.t. L if AR's loop contains L.
13746 if (AR->getLoop()->contains(L))
13747 return LoopInvariant;
13748
13749 // This recurrence is variant w.r.t. L if any of its operands
13750 // are variant.
13751 for (const auto *Op : AR->operands())
13752 if (!isLoopInvariant(Op, L))
13753 return LoopVariant;
13754
13755 // Otherwise it's loop-invariant.
13756 return LoopInvariant;
13757 }
13758 case scTruncate:
13759 case scZeroExtend:
13760 case scSignExtend:
13761 case scPtrToInt:
13762 case scAddExpr:
13763 case scMulExpr:
13764 case scUDivExpr:
13765 case scUMaxExpr:
13766 case scSMaxExpr:
13767 case scUMinExpr:
13768 case scSMinExpr:
13769 case scSequentialUMinExpr: {
13770 bool HasVarying = false;
13771 for (const auto *Op : S->operands()) {
13772 LoopDisposition D = getLoopDisposition(Op, L);
13773 if (D == LoopVariant)
13774 return LoopVariant;
13775 if (D == LoopComputable)
13776 HasVarying = true;
13777 }
13778 return HasVarying ? LoopComputable : LoopInvariant;
13779 }
13780 case scUnknown:
13781 // All non-instruction values are loop invariant. All instructions are loop
13782 // invariant if they are not contained in the specified loop.
13783 // Instructions are never considered invariant in the function body
13784 // (null loop) because they are defined within the "loop".
13785 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
13786 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
13787 return LoopInvariant;
13788 case scCouldNotCompute:
13789 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13790 }
13791 llvm_unreachable("Unknown SCEV kind!");
13792}
13793
13795 return getLoopDisposition(S, L) == LoopInvariant;
13796}
13797
13799 return getLoopDisposition(S, L) == LoopComputable;
13800}
13801
13802ScalarEvolution::BlockDisposition
13803ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13804 auto &Values = BlockDispositions[S];
13805 for (auto &V : Values) {
13806 if (V.getPointer() == BB)
13807 return V.getInt();
13808 }
13809 Values.emplace_back(BB, DoesNotDominateBlock);
13810 BlockDisposition D = computeBlockDisposition(S, BB);
13811 auto &Values2 = BlockDispositions[S];
13812 for (auto &V : llvm::reverse(Values2)) {
13813 if (V.getPointer() == BB) {
13814 V.setInt(D);
13815 break;
13816 }
13817 }
13818 return D;
13819}
13820
13821ScalarEvolution::BlockDisposition
13822ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13823 switch (S->getSCEVType()) {
13824 case scConstant:
13825 case scVScale:
13826 return ProperlyDominatesBlock;
13827 case scAddRecExpr: {
13828 // This uses a "dominates" query instead of "properly dominates" query
13829 // to test for proper dominance too, because the instruction which
13830 // produces the addrec's value is a PHI, and a PHI effectively properly
13831 // dominates its entire containing block.
13832 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13833 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
13834 return DoesNotDominateBlock;
13835
13836 // Fall through into SCEVNAryExpr handling.
13837 [[fallthrough]];
13838 }
13839 case scTruncate:
13840 case scZeroExtend:
13841 case scSignExtend:
13842 case scPtrToInt:
13843 case scAddExpr:
13844 case scMulExpr:
13845 case scUDivExpr:
13846 case scUMaxExpr:
13847 case scSMaxExpr:
13848 case scUMinExpr:
13849 case scSMinExpr:
13850 case scSequentialUMinExpr: {
13851 bool Proper = true;
13852 for (const SCEV *NAryOp : S->operands()) {
13853 BlockDisposition D = getBlockDisposition(NAryOp, BB);
13854 if (D == DoesNotDominateBlock)
13855 return DoesNotDominateBlock;
13856 if (D == DominatesBlock)
13857 Proper = false;
13858 }
13859 return Proper ? ProperlyDominatesBlock : DominatesBlock;
13860 }
13861 case scUnknown:
13862 if (Instruction *I =
13863 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
13864 if (I->getParent() == BB)
13865 return DominatesBlock;
13866 if (DT.properlyDominates(I->getParent(), BB))
13867 return ProperlyDominatesBlock;
13868 return DoesNotDominateBlock;
13869 }
13870 return ProperlyDominatesBlock;
13871 case scCouldNotCompute:
13872 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13873 }
13874 llvm_unreachable("Unknown SCEV kind!");
13875}
13876
13877bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
13878 return getBlockDisposition(S, BB) >= DominatesBlock;
13879}
13880
13881bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
13882 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
13883}
13884
13885bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
13886 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
13887}
13888
13889void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
13890 bool Predicated) {
13891 auto &BECounts =
13892 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
13893 auto It = BECounts.find(L);
13894 if (It != BECounts.end()) {
13895 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
13896 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
13897 if (!isa<SCEVConstant>(S)) {
13898 auto UserIt = BECountUsers.find(S);
13899 assert(UserIt != BECountUsers.end());
13900 UserIt->second.erase({L, Predicated});
13901 }
13902 }
13903 }
13904 BECounts.erase(It);
13905 }
13906}
13907
13908void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
13909 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
13910 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
13911
13912 while (!Worklist.empty()) {
13913 const SCEV *Curr = Worklist.pop_back_val();
13914 auto Users = SCEVUsers.find(Curr);
13915 if (Users != SCEVUsers.end())
13916 for (const auto *User : Users->second)
13917 if (ToForget.insert(User).second)
13918 Worklist.push_back(User);
13919 }
13920
13921 for (const auto *S : ToForget)
13922 forgetMemoizedResultsImpl(S);
13923
13924 for (auto I = PredicatedSCEVRewrites.begin();
13925 I != PredicatedSCEVRewrites.end();) {
13926 std::pair<const SCEV *, const Loop *> Entry = I->first;
13927 if (ToForget.count(Entry.first))
13928 PredicatedSCEVRewrites.erase(I++);
13929 else
13930 ++I;
13931 }
13932}
13933
13934void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
13935 LoopDispositions.erase(S);
13936 BlockDispositions.erase(S);
13937 UnsignedRanges.erase(S);
13938 SignedRanges.erase(S);
13939 HasRecMap.erase(S);
13940 ConstantMultipleCache.erase(S);
13941
13942 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
13943 UnsignedWrapViaInductionTried.erase(AR);
13944 SignedWrapViaInductionTried.erase(AR);
13945 }
13946
13947 auto ExprIt = ExprValueMap.find(S);
13948 if (ExprIt != ExprValueMap.end()) {
13949 for (Value *V : ExprIt->second) {
13950 auto ValueIt = ValueExprMap.find_as(V);
13951 if (ValueIt != ValueExprMap.end())
13952 ValueExprMap.erase(ValueIt);
13953 }
13954 ExprValueMap.erase(ExprIt);
13955 }
13956
13957 auto ScopeIt = ValuesAtScopes.find(S);
13958 if (ScopeIt != ValuesAtScopes.end()) {
13959 for (const auto &Pair : ScopeIt->second)
13960 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
13961 llvm::erase(ValuesAtScopesUsers[Pair.second],
13962 std::make_pair(Pair.first, S));
13963 ValuesAtScopes.erase(ScopeIt);
13964 }
13965
13966 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
13967 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
13968 for (const auto &Pair : ScopeUserIt->second)
13969 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
13970 ValuesAtScopesUsers.erase(ScopeUserIt);
13971 }
13972
13973 auto BEUsersIt = BECountUsers.find(S);
13974 if (BEUsersIt != BECountUsers.end()) {
13975 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
13976 auto Copy = BEUsersIt->second;
13977 for (const auto &Pair : Copy)
13978 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
13979 BECountUsers.erase(BEUsersIt);
13980 }
13981
13982 auto FoldUser = FoldCacheUser.find(S);
13983 if (FoldUser != FoldCacheUser.end())
13984 for (auto &KV : FoldUser->second)
13985 FoldCache.erase(KV);
13986 FoldCacheUser.erase(S);
13987}
13988
13989void
13990ScalarEvolution::getUsedLoops(const SCEV *S,
13991 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
13992 struct FindUsedLoops {
13993 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
13994 : LoopsUsed(LoopsUsed) {}
13995 SmallPtrSetImpl<const Loop *> &LoopsUsed;
13996 bool follow(const SCEV *S) {
13997 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
13998 LoopsUsed.insert(AR->getLoop());
13999 return true;
14000 }
14001
14002 bool isDone() const { return false; }
14003 };
14004
14005 FindUsedLoops F(LoopsUsed);
14006 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14007}
14008
14009void ScalarEvolution::getReachableBlocks(
14010 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14011 SmallVector<BasicBlock *> Worklist;
14012 Worklist.push_back(&F.getEntryBlock());
14013 while (!Worklist.empty()) {
14014 BasicBlock *BB = Worklist.pop_back_val();
14015 if (!Reachable.insert(BB).second)
14016 continue;
14017
14018 Value *Cond;
14019 BasicBlock *TrueBB, *FalseBB;
14020 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14021 m_BasicBlock(FalseBB)))) {
14022 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14023 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14024 continue;
14025 }
14026
14027 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14028 const SCEV *L = getSCEV(Cmp->getOperand(0));
14029 const SCEV *R = getSCEV(Cmp->getOperand(1));
14030 if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
14031 Worklist.push_back(TrueBB);
14032 continue;
14033 }
14034 if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
14035 R)) {
14036 Worklist.push_back(FalseBB);
14037 continue;
14038 }
14039 }
14040 }
14041
14042 append_range(Worklist, successors(BB));
14043 }
14044}
14045
14046void ScalarEvolution::verify() const {
14047 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14048 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14049
14050 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14051
14052 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14053 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14054 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14055
14056 const SCEV *visitConstant(const SCEVConstant *Constant) {
14057 return SE.getConstant(Constant->getAPInt());
14058 }
14059
14060 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14061 return SE.getUnknown(Expr->getValue());
14062 }
14063
14064 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14065 return SE.getCouldNotCompute();
14066 }
14067 };
14068
14069 SCEVMapper SCM(SE2);
14070 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14071 SE2.getReachableBlocks(ReachableBlocks, F);
14072
14073 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14074 if (containsUndefs(Old) || containsUndefs(New)) {
14075 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14076 // not propagate undef aggressively). This means we can (and do) fail
14077 // verification in cases where a transform makes a value go from "undef"
14078 // to "undef+1" (say). The transform is fine, since in both cases the
14079 // result is "undef", but SCEV thinks the value increased by 1.
14080 return nullptr;
14081 }
14082
14083 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14084 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14085 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14086 return nullptr;
14087
14088 return Delta;
14089 };
14090
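 // Informal summary of the strategy below: construct a fresh ScalarEvolution
 // instance (SE2) over the same function, translate each cached expression
 // into SE2's universe with SCEVMapper, and require that the recomputed
 // expression differs from the cached one by a delta of zero. Non-constant
 // deltas are only compared when strict verification is enabled (see GetDelta
 // above).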
14091 while (!LoopStack.empty()) {
14092 auto *L = LoopStack.pop_back_val();
14093 llvm::append_range(LoopStack, *L);
14094
14095 // Only verify BECounts in reachable loops. For an unreachable loop,
14096 // any BECount is legal.
14097 if (!ReachableBlocks.contains(L->getHeader()))
14098 continue;
14099
14100 // Only verify cached BECounts. Computing new BECounts may change the
14101 // results of subsequent SCEV uses.
14102 auto It = BackedgeTakenCounts.find(L);
14103 if (It == BackedgeTakenCounts.end())
14104 continue;
14105
14106 auto *CurBECount =
14107 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14108 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14109
14110 if (CurBECount == SE2.getCouldNotCompute() ||
14111 NewBECount == SE2.getCouldNotCompute()) {
14112 // NB! This situation is legal, but is very suspicious -- whatever pass
14113 // changed the loop to make a trip count go from could not compute to
14114 // computable or vice-versa *should have* invalidated SCEV. However, we
14115 // choose not to assert here (for now) since we don't want false
14116 // positives.
14117 continue;
14118 }
14119
14120 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14121 SE.getTypeSizeInBits(NewBECount->getType()))
14122 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14123 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14124 SE.getTypeSizeInBits(NewBECount->getType()))
14125 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14126
14127 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14128 if (Delta && !Delta->isZero()) {
14129 dbgs() << "Trip Count for " << *L << " Changed!\n";
14130 dbgs() << "Old: " << *CurBECount << "\n";
14131 dbgs() << "New: " << *NewBECount << "\n";
14132 dbgs() << "Delta: " << *Delta << "\n";
14133 std::abort();
14134 }
14135 }
14136
14137 // Collect all valid loops currently in LoopInfo.
14138 SmallPtrSet<Loop *, 32> ValidLoops;
14139 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14140 while (!Worklist.empty()) {
14141 Loop *L = Worklist.pop_back_val();
14142 if (ValidLoops.insert(L).second)
14143 Worklist.append(L->begin(), L->end());
14144 }
14145 for (const auto &KV : ValueExprMap) {
14146#ifndef NDEBUG
14147 // Check for SCEV expressions referencing invalid/deleted loops.
14148 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14149 assert(ValidLoops.contains(AR->getLoop()) &&
14150 "AddRec references invalid loop");
14151 }
14152#endif
14153
14154 // Check that the value is also part of the reverse map.
14155 auto It = ExprValueMap.find(KV.second);
14156 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14157 dbgs() << "Value " << *KV.first
14158 << " is in ValueExprMap but not in ExprValueMap\n";
14159 std::abort();
14160 }
14161
14162 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14163 if (!ReachableBlocks.contains(I->getParent()))
14164 continue;
14165 const SCEV *OldSCEV = SCM.visit(KV.second);
14166 const SCEV *NewSCEV = SE2.getSCEV(I);
14167 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14168 if (Delta && !Delta->isZero()) {
14169 dbgs() << "SCEV for value " << *I << " changed!\n"
14170 << "Old: " << *OldSCEV << "\n"
14171 << "New: " << *NewSCEV << "\n"
14172 << "Delta: " << *Delta << "\n";
14173 std::abort();
14174 }
14175 }
14176 }
14177
14178 for (const auto &KV : ExprValueMap) {
14179 for (Value *V : KV.second) {
14180 auto It = ValueExprMap.find_as(V);
14181 if (It == ValueExprMap.end()) {
14182 dbgs() << "Value " << *V
14183 << " is in ExprValueMap but not in ValueExprMap\n";
14184 std::abort();
14185 }
14186 if (It->second != KV.first) {
14187 dbgs() << "Value " << *V << " mapped to " << *It->second
14188 << " rather than " << *KV.first << "\n";
14189 std::abort();
14190 }
14191 }
14192 }
14193
14194 // Verify integrity of SCEV users.
14195 for (const auto &S : UniqueSCEVs) {
14196 for (const auto *Op : S.operands()) {
14197 // We do not store dependencies of constants.
14198 if (isa<SCEVConstant>(Op))
14199 continue;
14200 auto It = SCEVUsers.find(Op);
14201 if (It != SCEVUsers.end() && It->second.count(&S))
14202 continue;
14203 dbgs() << "Use of operand " << *Op << " by user " << S
14204 << " is not being tracked!\n";
14205 std::abort();
14206 }
14207 }
14208
14209 // Verify integrity of ValuesAtScopes users.
14210 for (const auto &ValueAndVec : ValuesAtScopes) {
14211 const SCEV *Value = ValueAndVec.first;
14212 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14213 const Loop *L = LoopAndValueAtScope.first;
14214 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14215 if (!isa<SCEVConstant>(ValueAtScope)) {
14216 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14217 if (It != ValuesAtScopesUsers.end() &&
14218 is_contained(It->second, std::make_pair(L, Value)))
14219 continue;
14220 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14221 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14222 std::abort();
14223 }
14224 }
14225 }
14226
14227 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14228 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14229 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14230 const Loop *L = LoopAndValue.first;
14231 const SCEV *Value = LoopAndValue.second;
14232 assert(!isa<SCEVConstant>(Value));
14233 auto It = ValuesAtScopes.find(Value);
14234 if (It != ValuesAtScopes.end() &&
14235 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14236 continue;
14237 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14238 << *ValueAtScope << " missing in ValuesAtScopes\n";
14239 std::abort();
14240 }
14241 }
14242
14243 // Verify integrity of BECountUsers.
14244 auto VerifyBECountUsers = [&](bool Predicated) {
14245 auto &BECounts =
14246 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14247 for (const auto &LoopAndBEInfo : BECounts) {
14248 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14249 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14250 if (!isa<SCEVConstant>(S)) {
14251 auto UserIt = BECountUsers.find(S);
14252 if (UserIt != BECountUsers.end() &&
14253 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14254 continue;
14255 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14256 << " missing from BECountUsers\n";
14257 std::abort();
14258 }
14259 }
14260 }
14261 }
14262 };
14263 VerifyBECountUsers(/* Predicated */ false);
14264 VerifyBECountUsers(/* Predicated */ true);
14265
14266 // Verify integrity of the loop disposition cache.
14267 for (auto &[S, Values] : LoopDispositions) {
14268 for (auto [Loop, CachedDisposition] : Values) {
14269 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14270 if (CachedDisposition != RecomputedDisposition) {
14271 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14272 << " is incorrect: cached " << CachedDisposition << ", actual "
14273 << RecomputedDisposition << "\n";
14274 std::abort();
14275 }
14276 }
14277 }
14278
14279 // Verify integrity of the block disposition cache.
14280 for (auto &[S, Values] : BlockDispositions) {
14281 for (auto [BB, CachedDisposition] : Values) {
14282 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14283 if (CachedDisposition != RecomputedDisposition) {
14284 dbgs() << "Cached disposition of " << *S << " for block %"
14285 << BB->getName() << " is incorrect: cached " << CachedDisposition
14286 << ", actual " << RecomputedDisposition << "\n";
14287 std::abort();
14288 }
14289 }
14290 }
14291
14292 // Verify FoldCache/FoldCacheUser caches.
14293 for (auto [FoldID, Expr] : FoldCache) {
14294 auto I = FoldCacheUser.find(Expr);
14295 if (I == FoldCacheUser.end()) {
14296 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14297 << "!\n";
14298 std::abort();
14299 }
14300 if (!is_contained(I->second, FoldID)) {
14301 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14302 std::abort();
14303 }
14304 }
14305 for (auto [Expr, IDs] : FoldCacheUser) {
14306 for (auto &FoldID : IDs) {
14307 auto I = FoldCache.find(FoldID);
14308 if (I == FoldCache.end()) {
14309 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14310 << "!\n";
14311 std::abort();
14312 }
14313 if (I->second != Expr) {
14314 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14315 << *I->second << " != " << *Expr << "!\n";
14316 std::abort();
14317 }
14318 }
14319 }
14320
14321 // Verify that ConstantMultipleCache computations are correct. We check that
14322 // cached multiples and recomputed multiples are multiples of each other to
14323 // verify correctness. It is possible that a recomputed multiple is different
14324 // from the cached multiple due to strengthened no wrap flags or changes in
14325 // KnownBits computations.
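 // E.g. a cached multiple of 4 with a recomputed multiple of 8 (or vice
 // versa) is accepted, while 4 vs. 6 would indicate a stale cache entry.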
14326 for (auto [S, Multiple] : ConstantMultipleCache) {
14327 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14328 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14329 Multiple.urem(RecomputedMultiple) != 0 &&
14330 RecomputedMultiple.urem(Multiple) != 0)) {
14331 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14332 << *S << " : Computed " << RecomputedMultiple
14333 << " but cache contains " << Multiple << "!\n";
14334 std::abort();
14335 }
14336 }
14337}
14338
14339bool ScalarEvolution::invalidate(
14340 Function &F, const PreservedAnalyses &PA,
14341 FunctionAnalysisManager::Invalidator &Inv) {
14342 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14343 // of its dependencies is invalidated.
14344 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14345 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14346 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14347 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14348 Inv.invalidate<LoopAnalysis>(F, PA);
14349}
14350
14351AnalysisKey ScalarEvolutionAnalysis::Key;
14352
14353ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14354 FunctionAnalysisManager &AM) {
14355 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14356 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14357 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14358 auto &LI = AM.getResult<LoopAnalysis>(F);
14359 return ScalarEvolution(F, TLI, AC, DT, LI);
14360}
14361
14362PreservedAnalyses
14363ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14364 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14365 return PreservedAnalyses::all();
14366}
14367
14368PreservedAnalyses
14369ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14370 // For compatibility with opt's -analyze feature under legacy pass manager
14371 // which was not ported to NPM. This keeps tests using
14372 // update_analyze_test_checks.py working.
14373 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14374 << F.getName() << "':\n";
14375 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14376 return PreservedAnalyses::all();
14377}
14378
14380 "Scalar Evolution Analysis", false, true)
14386 "Scalar Evolution Analysis", false, true)
14387
14388char ScalarEvolutionWrapperPass::ID = 0;
14389
14390ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
14391 initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
14392}
14393
14394bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14395 SE.reset(new ScalarEvolution(
14396 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14397 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14398 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14399 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14400 return false;
14401}
14402
14403void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14404
14405void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14406 SE->print(OS);
14407}
14408
14409void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14410 if (!VerifySCEV)
14411 return;
14412
14413 SE->verify();
14414}
14415
14416void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14417 AU.setPreservesAll();
14418 AU.addRequiredTransitive<AssumptionCacheTracker>();
14419 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14420 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14421 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14422}
14423
14424const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14425 const SCEV *RHS) {
14426 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14427}
14428
14429const SCEVPredicate *
14430ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14431 const SCEV *LHS, const SCEV *RHS) {
14432 FoldingSetNodeID ID;
14433 assert(LHS->getType() == RHS->getType() &&
14434 "Type mismatch between LHS and RHS");
14435 // Unique this node based on the arguments
14436 ID.AddInteger(SCEVPredicate::P_Compare);
14437 ID.AddInteger(Pred);
14438 ID.AddPointer(LHS);
14439 ID.AddPointer(RHS);
14440 void *IP = nullptr;
14441 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14442 return S;
14443 SCEVComparePredicate *Eq = new (SCEVAllocator)
14444 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14445 UniquePreds.InsertNode(Eq, IP);
14446 return Eq;
14447}
14448
14449const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14450 const SCEVAddRecExpr *AR,
14451 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14452 FoldingSetNodeID ID;
14453 // Unique this node based on the arguments
14454 ID.AddInteger(SCEVPredicate::P_Wrap);
14455 ID.AddPointer(AR);
14456 ID.AddInteger(AddedFlags);
14457 void *IP = nullptr;
14458 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14459 return S;
14460 auto *OF = new (SCEVAllocator)
14461 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14462 UniquePreds.InsertNode(OF, IP);
14463 return OF;
14464}
14465
14466namespace {
14467
14468class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14469public:
14470
14471 /// Rewrites \p S in the context of a loop L and the SCEV predication
14472 /// infrastructure.
14473 ///
14474 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14475 /// equivalences present in \p Pred.
14476 ///
14477 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14478 /// \p NewPreds such that the result will be an AddRecExpr.
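 ///
 /// Illustrative sketch: a PHI %iv that can only be modeled as the add
 /// recurrence {%start,+,%step}<%loop> under a no-wrap assumption on its
 /// increment may be returned as that AddRecExpr here, with the required
 /// SCEVWrapPredicate recorded in \p NewPreds for the caller to check at
 /// runtime.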
14479 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14480 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14481 const SCEVPredicate *Pred) {
14482 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14483 return Rewriter.visit(S);
14484 }
14485
14486 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14487 if (Pred) {
14488 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14489 for (const auto *Pred : U->getPredicates())
14490 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14491 if (IPred->getLHS() == Expr &&
14492 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14493 return IPred->getRHS();
14494 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14495 if (IPred->getLHS() == Expr &&
14496 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14497 return IPred->getRHS();
14498 }
14499 }
14500 return convertToAddRecWithPreds(Expr);
14501 }
14502
14503 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14504 const SCEV *Operand = visit(Expr->getOperand());
14505 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14506 if (AR && AR->getLoop() == L && AR->isAffine()) {
14507 // This couldn't be folded because the operand didn't have the nuw
14508 // flag. Add the nusw flag as an assumption that we could make.
14509 const SCEV *Step = AR->getStepRecurrence(SE);
14510 Type *Ty = Expr->getType();
14511 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14512 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14513 SE.getSignExtendExpr(Step, Ty), L,
14514 AR->getNoWrapFlags());
14515 }
14516 return SE.getZeroExtendExpr(Operand, Expr->getType());
14517 }
14518
14519 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14520 const SCEV *Operand = visit(Expr->getOperand());
14521 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14522 if (AR && AR->getLoop() == L && AR->isAffine()) {
14523 // This couldn't be folded because the operand didn't have the nsw
14524 // flag. Add the nssw flag as an assumption that we could make.
14525 const SCEV *Step = AR->getStepRecurrence(SE);
14526 Type *Ty = Expr->getType();
14527 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14528 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14529 SE.getSignExtendExpr(Step, Ty), L,
14530 AR->getNoWrapFlags());
14531 }
14532 return SE.getSignExtendExpr(Operand, Expr->getType());
14533 }
14534
14535private:
14536 explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
14537 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14538 const SCEVPredicate *Pred)
14539 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14540
14541 bool addOverflowAssumption(const SCEVPredicate *P) {
14542 if (!NewPreds) {
14543 // Check if we've already made this assumption.
14544 return Pred && Pred->implies(P);
14545 }
14546 NewPreds->insert(P);
14547 return true;
14548 }
14549
14550 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14551 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14552 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14553 return addOverflowAssumption(A);
14554 }
14555
14556 // If \p Expr represents a PHINode, we try to see if it can be represented
14557 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14558 // to add this predicate as a runtime overflow check, we return the AddRec.
14559 // If \p Expr does not meet these conditions (is not a PHI node, or we
14560 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14561 // return \p Expr.
14562 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14563 if (!isa<PHINode>(Expr->getValue()))
14564 return Expr;
14565 std::optional<
14566 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14567 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14568 if (!PredicatedRewrite)
14569 return Expr;
14570 for (const auto *P : PredicatedRewrite->second){
14571 // Wrap predicates from outer loops are not supported.
14572 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14573 if (L != WP->getExpr()->getLoop())
14574 return Expr;
14575 }
14576 if (!addOverflowAssumption(P))
14577 return Expr;
14578 }
14579 return PredicatedRewrite->first;
14580 }
14581
14582 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
14583 const SCEVPredicate *Pred;
14584 const Loop *L;
14585};
14586
14587} // end anonymous namespace
14588
14589const SCEV *
14590ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14591 const SCEVPredicate &Preds) {
14592 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14593}
14594
14595const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14596 const SCEV *S, const Loop *L,
14597 SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
14598 SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
14599 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14600 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14601
14602 if (!AddRec)
14603 return nullptr;
14604
14605 // Since the transformation was successful, we can now transfer the SCEV
14606 // predicates.
14607 for (const auto *P : TransformPreds)
14608 Preds.insert(P);
14609
14610 return AddRec;
14611}
14612
14613/// SCEV predicates
14614SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
14615 SCEVPredicateKind Kind)
14616 : FastID(ID), Kind(Kind) {}
14617
14618SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14619 const ICmpInst::Predicate Pred,
14620 const SCEV *LHS, const SCEV *RHS)
14621 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14622 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14623 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14624}
14625
14626bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14627 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14628
14629 if (!Op)
14630 return false;
14631
14632 if (Pred != ICmpInst::ICMP_EQ)
14633 return false;
14634
14635 return Op->LHS == LHS && Op->RHS == RHS;
14636}
14637
14638bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14639
14640void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
14641 if (Pred == ICmpInst::ICMP_EQ)
14642 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14643 else
14644 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
14645 << *RHS << "\n";
14646
14647}
14648
14649SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14650 const SCEVAddRecExpr *AR,
14651 IncrementWrapFlags Flags)
14652 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14653
14654const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14655
14656bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14657 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14658
14659 return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14660}
14661
14662bool SCEVWrapPredicate::isAlwaysTrue() const {
14663 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
14664 IncrementWrapFlags IFlags = Flags;
14665
14666 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
14667 IFlags = clearFlags(IFlags, IncrementNSSW);
14668
14669 return IFlags == IncrementAnyWrap;
14670}
14671
14672void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
14673 OS.indent(Depth) << *getExpr() << " Added Flags: ";
14674 if (SCEVWrapPredicate::IncrementNUSW & getFlags())
14675 OS << "<nusw>";
14676 if (SCEVWrapPredicate::IncrementNSSW & getFlags())
14677 OS << "<nssw>";
14678 OS << "\n";
14679}
14680
14681SCEVWrapPredicate::IncrementWrapFlags
14682SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
14683 ScalarEvolution &SE) {
14684 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
14685 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
14686
14687 // We can safely transfer the NSW flag as NSSW.
14688 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
14689 ImpliedFlags = IncrementNSSW;
14690
14691 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
14692 // If the increment is positive, the SCEV NUW flag will also imply the
14693 // WrapPredicate NUSW flag.
14694 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
14695 if (Step->getValue()->getValue().isNonNegative())
14696 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
14697 }
14698
14699 return ImpliedFlags;
14700}
14701
14702/// Union predicates don't get cached so create a dummy set ID for it.
14703SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
14704 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
14705 for (const auto *P : Preds)
14706 add(P);
14707}
14708
14709bool SCEVUnionPredicate::isAlwaysTrue() const {
14710 return all_of(Preds,
14711 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
14712}
14713
14714bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
14715 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14716 return all_of(Set->Preds,
14717 [this](const SCEVPredicate *I) { return this->implies(I); });
14718
14719 return any_of(Preds,
14720 [N](const SCEVPredicate *I) { return I->implies(N); });
14721}
14722
14723void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
14724 for (const auto *Pred : Preds)
14725 Pred->print(OS, Depth);
14726}
14727
14728void SCEVUnionPredicate::add(const SCEVPredicate *N) {
14729 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
14730 for (const auto *Pred : Set->Preds)
14731 add(Pred);
14732 return;
14733 }
14734
14735 Preds.push_back(N);
14736}
14737
14738PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
14739 Loop &L)
14740 : SE(SE), L(L) {
14741 SmallVector<const SCEVPredicate*, 4> Empty;
14742 Preds = std::make_unique<SCEVUnionPredicate>(Empty);
14743}
14744
14745void ScalarEvolution::registerUser(const SCEV *User,
14746 ArrayRef<const SCEV *> Ops) {
14747 for (const auto *Op : Ops)
14748 // We do not expect that forgetting cached data for SCEVConstants will ever
14749 // open any prospects for sharpening or introduce any correctness issues,
14750 // so we don't bother storing their dependencies.
14751 if (!isa<SCEVConstant>(Op))
14752 SCEVUsers[Op].insert(User);
14753}
14754
14755const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
14756 const SCEV *Expr = SE.getSCEV(V);
14757 RewriteEntry &Entry = RewriteMap[Expr];
14758
14759 // If we already have an entry and the version matches, return it.
14760 if (Entry.second && Generation == Entry.first)
14761 return Entry.second;
14762
14763 // We found an entry but it's stale. Rewrite the stale entry
14764 // according to the current predicate.
14765 if (Entry.second)
14766 Expr = Entry.second;
14767
14768 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
14769 Entry = {Generation, NewSCEV};
14770
14771 return NewSCEV;
14772}
14773
14774const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
14775 if (!BackedgeCount) {
14776 SmallVector<const SCEVPredicate *, 4> Preds;
14777 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
14778 for (const auto *P : Preds)
14779 addPredicate(*P);
14780 }
14781 return BackedgeCount;
14782}
14783
14784void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
14785 if (Preds->implies(&Pred))
14786 return;
14787
14788 auto &OldPreds = Preds->getPredicates();
14789 SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end());
14790 NewPreds.push_back(&Pred);
14791 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
14792 updateGeneration();
14793}
14794
14795const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
14796 return *Preds;
14797}
14798
14799void PredicatedScalarEvolution::updateGeneration() {
14800 // If the generation number wrapped recompute everything.
14801 if (++Generation == 0) {
14802 for (auto &II : RewriteMap) {
14803 const SCEV *Rewritten = II.second.second;
14804 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
14805 }
14806 }
14807}
14808
14809void PredicatedScalarEvolution::setNoOverflow(
14810 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14811 const SCEV *Expr = getSCEV(V);
14812 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14813
14814 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
14815
14816 // Clear the statically implied flags.
14817 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
14818 addPredicate(*SE.getWrapPredicate(AR, Flags));
14819
14820 auto II = FlagsMap.insert({V, Flags});
14821 if (!II.second)
14822 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
14823}
14824
14825bool PredicatedScalarEvolution::hasNoOverflow(
14826 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14827 const SCEV *Expr = getSCEV(V);
14828 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14829
14830 Flags = SCEVWrapPredicate::clearFlags(
14831 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
14832
14833 auto II = FlagsMap.find(V);
14834
14835 if (II != FlagsMap.end())
14836 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
14837
14838 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
14839}
14840
14841const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
14842 const SCEV *Expr = this->getSCEV(V);
14843 SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
14844 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
14845
14846 if (!New)
14847 return nullptr;
14848
14849 for (const auto *P : NewPreds)
14850 addPredicate(*P);
14851
14852 RewriteMap[SE.getSCEV(V)] = {Generation, New};
14853 return New;
14854}
14855
14856PredicatedScalarEvolution::PredicatedScalarEvolution(
14857 const PredicatedScalarEvolution &Init)
14858 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
14859 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
14860 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
14861 for (auto I : Init.FlagsMap)
14862 FlagsMap.insert(I);
14863}
14864
14865void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
14866 // For each block.
14867 for (auto *BB : L.getBlocks())
14868 for (auto &I : *BB) {
14869 if (!SE.isSCEVable(I.getType()))
14870 continue;
14871
14872 auto *Expr = SE.getSCEV(&I);
14873 auto II = RewriteMap.find(Expr);
14874
14875 if (II == RewriteMap.end())
14876 continue;
14877
14878 // Don't print things that are not interesting.
14879 if (II->second.second == Expr)
14880 continue;
14881
14882 OS.indent(Depth) << "[PSE]" << I << ":\n";
14883 OS.indent(Depth + 2) << *Expr << "\n";
14884 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
14885 }
14886}
14887
14888// Match the mathematical pattern A - (A / B) * B, where A and B can be
14889// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
14890// for URem with constant power-of-2 second operands.
14891// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
14892// 4, A / B becomes X / 8).
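// Illustrative examples (not exhaustive): for a power-of-2 divisor, the
// canonical SCEV form of "%a urem 8" is (zext i3 (trunc %a to i3) to i32),
// which the first pattern below recognizes; for a general divisor %b the
// canonical form is ((-1 * (%a /u %b)) * %b + %a), matched by the add/mul
// patterns further down.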
14893bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
14894 const SCEV *&RHS) {
14895 // Try to match 'zext (trunc A to iB) to iY', which is used
14896 // for URem with constant power-of-2 second operands. Make sure the size of
14897 // the operand A matches the size of the whole expression.
14898 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
14899 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
14900 LHS = Trunc->getOperand();
14901 // Bail out if the type of the LHS is larger than the type of the
14902 // expression for now.
14903 if (getTypeSizeInBits(LHS->getType()) >
14904 getTypeSizeInBits(Expr->getType()))
14905 return false;
14906 if (LHS->getType() != Expr->getType())
14907 LHS = getZeroExtendExpr(LHS, Expr->getType());
14908 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
14909 << getTypeSizeInBits(Trunc->getType()));
14910 return true;
14911 }
14912 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
14913 if (Add == nullptr || Add->getNumOperands() != 2)
14914 return false;
14915
14916 const SCEV *A = Add->getOperand(1);
14917 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
14918
14919 if (Mul == nullptr)
14920 return false;
14921
14922 const auto MatchURemWithDivisor = [&](const SCEV *B) {
14923 // (SomeExpr + (-(SomeExpr / B) * B)).
14924 if (Expr == getURemExpr(A, B)) {
14925 LHS = A;
14926 RHS = B;
14927 return true;
14928 }
14929 return false;
14930 };
14931
14932 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
14933 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
14934 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14935 MatchURemWithDivisor(Mul->getOperand(2));
14936
14937 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
14938 if (Mul->getNumOperands() == 2)
14939 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14940 MatchURemWithDivisor(Mul->getOperand(0)) ||
14941 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
14942 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
14943 return false;
14944}
14945
14946const SCEV *
14947ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
14948 SmallVector<BasicBlock*, 16> ExitingBlocks;
14949 L->getExitingBlocks(ExitingBlocks);
14950
14951 // Form an expression for the maximum exit count possible for this loop. We
14952 // merge the max and exact information to approximate a version of
14953 // getConstantMaxBackedgeTakenCount which isn't restricted to just constants.
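 // For instance, if a loop has two exiting blocks (both dominating the latch)
 // with exit counts %n and 100, the result below is umin_seq(%n, 100).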
14954 SmallVector<const SCEV*, 4> ExitCounts;
14955 for (BasicBlock *ExitingBB : ExitingBlocks) {
14956 const SCEV *ExitCount =
14957 getExitCount(L, ExitingBB, ScalarEvolution::SymbolicMaximum);
14958 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
14959 assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
14960 "We should only have known counts for exiting blocks that "
14961 "dominate latch!");
14962 ExitCounts.push_back(ExitCount);
14963 }
14964 }
14965 if (ExitCounts.empty())
14966 return getCouldNotCompute();
14967 return getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
14968}
14969
14970/// A rewriter to replace SCEV expressions in Map with the corresponding entry
14971/// in the map. It skips AddRecExpr because we cannot guarantee that the
14972/// replacement is loop invariant in the loop of the AddRec.
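/// For example, if a dominating guard "%n u< 16" produced the map entry
/// {%n -> (15 umin %n)}, rewriting (1 + %n) yields (1 + (15 umin %n)), while
/// an AddRecExpr such as {0,+,1}<%loop> is returned unchanged.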
14973class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
14974 const DenseMap<const SCEV *, const SCEV *> &Map;
14975
14976public:
14977 SCEVLoopGuardRewriter(ScalarEvolution &SE,
14978 DenseMap<const SCEV *, const SCEV *> &M)
14979 : SCEVRewriteVisitor(SE), Map(M) {}
14980
14981 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
14982
14983 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14984 auto I = Map.find(Expr);
14985 if (I == Map.end())
14986 return Expr;
14987 return I->second;
14988 }
14989
14990 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14991 auto I = Map.find(Expr);
14992 if (I == Map.end()) {
14993 // If we didn't find the exact ZExt expr in the map, check if there's an
14994 // entry for a smaller ZExt we can use instead.
14995 Type *Ty = Expr->getType();
14996 const SCEV *Op = Expr->getOperand(0);
14997 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
14998 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
14999 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15000 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15001 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15002 auto I = Map.find(NarrowExt);
15003 if (I != Map.end())
15004 return SE.getZeroExtendExpr(I->second, Ty);
15005 Bitwidth = Bitwidth / 2;
15006 }
15007
15008 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15009 Expr);
15010 }
15011 return I->second;
15012 }
15013
15014 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15015 auto I = Map.find(Expr);
15016 if (I == Map.end())
15017 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
15018 Expr);
15019 return I->second;
15020 }
15021
15022 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15023 auto I = Map.find(Expr);
15024 if (I == Map.end())
15025 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
15026 return I->second;
15027 }
15028
15029 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15030 auto I = Map.find(Expr);
15031 if (I == Map.end())
15032 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
15033 return I->second;
15034 }
15035};
15036
15037const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15038 SmallVector<const SCEV *> ExprsToRewrite;
15039 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15040 const SCEV *RHS,
15041 DenseMap<const SCEV *, const SCEV *>
15042 &RewriteMap) {
15043 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15044 // replacement SCEV which isn't directly implied by the structure of that
15045 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15046 // legal. See the scoping rules for flags in the header to understand why.
15047
15048 // If LHS is a constant, apply information to the other expression.
15049 if (isa<SCEVConstant>(LHS)) {
15050 std::swap(LHS, RHS);
15051 Predicate = CmpInst::getSwappedPredicate(Predicate);
15052 }
15053
15054 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15055 // create this form when combining two checks of the form (X u< C2 + C1) and
15056 // (X >=u C1).
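 // E.g. with C1 == 4 and C2 == 10, the combined check "(-4 + X) u< 10"
 // encodes 4 u<= X u< 14, so X can be rewritten below as
 // umax(4, umin(X, 13)).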
15057 auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
15058 &ExprsToRewrite]() {
15059 auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
15060 if (!AddExpr || AddExpr->getNumOperands() != 2)
15061 return false;
15062
15063 auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
15064 auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
15065 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15066 if (!C1 || !C2 || !LHSUnknown)
15067 return false;
15068
15069 auto ExactRegion =
15070 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15071 .sub(C1->getAPInt());
15072
15073 // Bail out, unless we have a non-wrapping, monotonic range.
15074 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15075 return false;
15076 auto I = RewriteMap.find(LHSUnknown);
15077 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
15078 RewriteMap[LHSUnknown] = getUMaxExpr(
15079 getConstant(ExactRegion.getUnsignedMin()),
15080 getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
15081 ExprsToRewrite.push_back(LHSUnknown);
15082 return true;
15083 };
15084 if (MatchRangeCheckIdiom())
15085 return;
15086
15087 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15088 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15089 // the non-constant operand and in \p LHS the constant operand.
15090 auto IsMinMaxSCEVWithNonNegativeConstant =
15091 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15092 const SCEV *&RHS) {
15093 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15094 if (MinMax->getNumOperands() != 2)
15095 return false;
15096 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15097 if (C->getAPInt().isNegative())
15098 return false;
15099 SCTy = MinMax->getSCEVType();
15100 LHS = MinMax->getOperand(0);
15101 RHS = MinMax->getOperand(1);
15102 return true;
15103 }
15104 }
15105 return false;
15106 };
15107
15108 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15109 // constant, and returns their APInt in ExprVal and in DivisorVal.
15110 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15111 APInt &ExprVal, APInt &DivisorVal) {
15112 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15113 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15114 if (!ConstExpr || !ConstDivisor)
15115 return false;
15116 ExprVal = ConstExpr->getAPInt();
15117 DivisorVal = ConstDivisor->getAPInt();
15118 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15119 };
15120
15121 // Return a new SCEV that modifies \p Expr to the closest number that is
15122 // divisible by \p Divisor and greater than or equal to \p Expr.
15123 // For now, only handle constant Expr and Divisor.
15124 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15125 const SCEV *Divisor) {
15126 APInt ExprVal;
15127 APInt DivisorVal;
15128 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15129 return Expr;
15130 APInt Rem = ExprVal.urem(DivisorVal);
15131 if (!Rem.isZero())
15132 // return the SCEV: Expr + Divisor - Expr % Divisor
15133 return getConstant(ExprVal + DivisorVal - Rem);
15134 return Expr;
15135 };
15136
15137 // Return a new SCEV that modifies \p Expr to the closest number that is
15138 // divisible by \p Divisor and less than or equal to \p Expr.
15139 // For now, only handle constant Expr and Divisor.
15140 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15141 const SCEV *Divisor) {
15142 APInt ExprVal;
15143 APInt DivisorVal;
15144 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15145 return Expr;
15146 APInt Rem = ExprVal.urem(DivisorVal);
15147 // return the SCEV: Expr - Expr % Divisor
15148 return getConstant(ExprVal - Rem);
15149 };
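 // E.g. with Expr == 5 and Divisor == 4, GetNextSCEVDividesByDivisor yields 8
 // and GetPreviousSCEVDividesByDivisor yields 4; non-constant inputs are
 // returned unchanged by both helpers.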
15150
15151 // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15152 // recursively. This is done by aligning the constant value up/down to the
15153 // Divisor.
15154 std::function<const SCEV *(const SCEV *, const SCEV *)>
15155 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15156 const SCEV *Divisor) {
15157 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15158 SCEVTypes SCTy;
15159 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15160 MinMaxRHS))
15161 return MinMaxExpr;
15162 auto IsMin =
15163 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15164 assert(isKnownNonNegative(MinMaxLHS) &&
15165 "Expected non-negative operand!");
15166 auto *DivisibleExpr =
15167 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15168 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15169 SmallVector<const SCEV *> Ops = {
15170 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15171 return getMinMaxExpr(SCTy, Ops);
15172 };
15173
15174 // If we have LHS == 0, check if LHS is computing a property of some unknown
15175 // SCEV %v which we can rewrite %v to express explicitly.
15176 const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
15177 if (Predicate == CmpInst::ICMP_EQ && RHSC &&
15178 RHSC->getValue()->isNullValue()) {
15179 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15180 // explicitly express that.
15181 const SCEV *URemLHS = nullptr;
15182 const SCEV *URemRHS = nullptr;
15183 if (matchURem(LHS, URemLHS, URemRHS)) {
15184 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15185 auto I = RewriteMap.find(LHSUnknown);
15186 const SCEV *RewrittenLHS =
15187 I != RewriteMap.end() ? I->second : LHSUnknown;
15188 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15189 const auto *Multiple =
15190 getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15191 RewriteMap[LHSUnknown] = Multiple;
15192 ExprsToRewrite.push_back(LHSUnknown);
15193 return;
15194 }
15195 }
15196 }
15197
15198 // Do not apply information for constants or if RHS contains an AddRec.
15199 if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
15200 return;
15201
15202 // If RHS is SCEVUnknown, make sure the information is applied to it.
15203 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15204 std::swap(LHS, RHS);
15205 Predicate = CmpInst::getSwappedPredicate(Predicate);
15206 }
15207
15208 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15209 // and \p FromRewritten are the same (i.e. there has been no rewrite
15210 // registered for \p From), then puts this value in the list of rewritten
15211 // expressions.
15212 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15213 const SCEV *To) {
15214 if (From == FromRewritten)
15215 ExprsToRewrite.push_back(From);
15216 RewriteMap[From] = To;
15217 };
15218
15219 // Checks whether \p S has already been rewritten. In that case returns the
15220 // existing rewrite because we want to chain further rewrites onto the
15221 // already rewritten value. Otherwise returns \p S.
15222 auto GetMaybeRewritten = [&](const SCEV *S) {
15223 auto I = RewriteMap.find(S);
15224 return I != RewriteMap.end() ? I->second : S;
15225 };
15226
15227 // Check for the SCEV expression (A /u B) * B while B is a constant, inside
15228 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15229 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15230 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15231 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15232 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15233 // DividesBy.
15234 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15235 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15236 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15237 if (Mul->getNumOperands() != 2)
15238 return false;
15239 auto *MulLHS = Mul->getOperand(0);
15240 auto *MulRHS = Mul->getOperand(1);
15241 if (isa<SCEVConstant>(MulLHS))
15242 std::swap(MulLHS, MulRHS);
15243 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15244 if (Div->getOperand(1) == MulRHS) {
15245 DividesBy = MulRHS;
15246 return true;
15247 }
15248 }
15249 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15250 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15251 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15252 return false;
15253 };
15254
15255 // Return true if \p Expr is known to be divisible by \p DividesBy.
15256 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15257 [&](const SCEV *Expr, const SCEV *DividesBy) {
15258 if (getURemExpr(Expr, DividesBy)->isZero())
15259 return true;
15260 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15261 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15262 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15263 return false;
15264 };
15265
15266 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15267 const SCEV *DividesBy = nullptr;
15268 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15269 // Check that the whole expression is divided by DividesBy
15270 DividesBy =
15271 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15272
15273 // Collect rewrites for LHS and its transitive operands based on the
15274 // condition.
15275 // For min/max expressions, also apply the guard to its operands:
15276 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15277 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15278 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15279 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15280
15281 // We cannot express strict predicates in SCEV, so instead we replace them
15282 // with non-strict ones against plus or minus one of RHS depending on the
15283 // predicate.
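 // E.g. "X u< 5" is handled below as "X u<= 4"; and when the rewritten LHS is
 // of the form umax(A, B), the operands A and B are also enqueued so each
 // picks up the same upper bound.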
15284 const SCEV *One = getOne(RHS->getType());
15285 switch (Predicate) {
15286 case CmpInst::ICMP_ULT:
15287 if (RHS->getType()->isPointerTy())
15288 return;
15289 RHS = getUMaxExpr(RHS, One);
15290 [[fallthrough]];
15291 case CmpInst::ICMP_SLT: {
15292 RHS = getMinusSCEV(RHS, One);
15293 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15294 break;
15295 }
15296 case CmpInst::ICMP_UGT:
15297 case CmpInst::ICMP_SGT:
15298 RHS = getAddExpr(RHS, One);
15299 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15300 break;
15301 case CmpInst::ICMP_ULE:
15302 case CmpInst::ICMP_SLE:
15303 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15304 break;
15305 case CmpInst::ICMP_UGE:
15306 case CmpInst::ICMP_SGE:
15307 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15308 break;
15309 default:
15310 break;
15311 }
15312
15315
15316 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15317 append_range(Worklist, S->operands());
15318 };
15319
15320 while (!Worklist.empty()) {
15321 const SCEV *From = Worklist.pop_back_val();
15322 if (isa<SCEVConstant>(From))
15323 continue;
15324 if (!Visited.insert(From).second)
15325 continue;
15326 const SCEV *FromRewritten = GetMaybeRewritten(From);
15327 const SCEV *To = nullptr;
15328
15329 switch (Predicate) {
15330 case CmpInst::ICMP_ULT:
15331 case CmpInst::ICMP_ULE:
15332 To = getUMinExpr(FromRewritten, RHS);
15333 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15334 EnqueueOperands(UMax);
15335 break;
15336 case CmpInst::ICMP_SLT:
15337 case CmpInst::ICMP_SLE:
15338 To = getSMinExpr(FromRewritten, RHS);
15339 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15340 EnqueueOperands(SMax);
15341 break;
15342 case CmpInst::ICMP_UGT:
15343 case CmpInst::ICMP_UGE:
15344 To = getUMaxExpr(FromRewritten, RHS);
15345 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15346 EnqueueOperands(UMin);
15347 break;
15348 case CmpInst::ICMP_SGT:
15349 case CmpInst::ICMP_SGE:
15350 To = getSMaxExpr(FromRewritten, RHS);
15351 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15352 EnqueueOperands(SMin);
15353 break;
15354 case CmpInst::ICMP_EQ:
15355 if (isa<SCEVConstant>(RHS))
15356 To = RHS;
15357 break;
15358 case CmpInst::ICMP_NE:
15359 if (isa<SCEVConstant>(RHS) &&
15360 cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15361 const SCEV *OneAlignedUp =
15362 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15363 To = getUMaxExpr(FromRewritten, OneAlignedUp);
15364 }
15365 break;
15366 default:
15367 break;
15368 }
15369
15370 if (To)
15371 AddRewrite(From, FromRewritten, To);
15372 }
15373 };
15374
15375 BasicBlock *Header = L->getHeader();
15376 SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
15377 // First, collect information from assumptions dominating the loop.
15378 for (auto &AssumeVH : AC.assumptions()) {
15379 if (!AssumeVH)
15380 continue;
15381 auto *AssumeI = cast<CallInst>(AssumeVH);
15382 if (!DT.dominates(AssumeI, Header))
15383 continue;
15384 Terms.emplace_back(AssumeI->getOperand(0), true);
15385 }
15386
15387 // Second, collect information from llvm.experimental.guards dominating the loop.
15388 auto *GuardDecl = F.getParent()->getFunction(
15389 Intrinsic::getName(Intrinsic::experimental_guard));
15390 if (GuardDecl)
15391 for (const auto *GU : GuardDecl->users())
15392 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15393 if (Guard->getFunction() == Header->getParent() && DT.dominates(Guard, Header))
15394 Terms.emplace_back(Guard->getArgOperand(0), true);
15395
15396 // Third, collect conditions from dominating branches. Starting at the loop
15397 // predecessor, climb up the predecessor chain, as long as there are
15398 // predecessors that can be found that have unique successors leading to the
15399 // original header.
15400 // TODO: share this logic with isLoopEntryGuardedByCond.
15401 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
15402 L->getLoopPredecessor(), Header);
15403 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15404
15405 const BranchInst *LoopEntryPredicate =
15406 dyn_cast<BranchInst>(Pair.first->getTerminator());
15407 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15408 continue;
15409
15410 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15411 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15412 }
15413
15414 // Now apply the information from the collected conditions to RewriteMap.
15415 // Conditions are processed in reverse order, so the earliest conditions are
15416 // processed first. This ensures the SCEVs with the shortest dependency chains
15417 // are constructed first.
15418 DenseMap<const SCEV *, const SCEV *> RewriteMap;
15419 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15420 SmallVector<Value *, 8> Worklist;
15421 SmallPtrSet<Value *, 8> Visited;
15422 Worklist.push_back(Term);
15423 while (!Worklist.empty()) {
15424 Value *Cond = Worklist.pop_back_val();
15425 if (!Visited.insert(Cond).second)
15426 continue;
15427
15428 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15429 auto Predicate =
15430 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15431 const auto *LHS = getSCEV(Cmp->getOperand(0));
15432 const auto *RHS = getSCEV(Cmp->getOperand(1));
15433 CollectCondition(Predicate, LHS, RHS, RewriteMap);
15434 continue;
15435 }
15436
15437 Value *L, *R;
15438 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15439 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15440 Worklist.push_back(L);
15441 Worklist.push_back(R);
15442 }
15443 }
15444 }
15445
15446 if (RewriteMap.empty())
15447 return Expr;
15448
15449 // Now that all rewrite information is collected, rewrite the collected
15450 // expressions with the information in the map. This applies information to
15451 // sub-expressions.
15452 if (ExprsToRewrite.size() > 1) {
15453 for (const SCEV *Expr : ExprsToRewrite) {
15454 const SCEV *RewriteTo = RewriteMap[Expr];
15455 RewriteMap.erase(Expr);
15456 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
15457 RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
15458 }
15459 }
15460
15461 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
15462 return Rewriter.visit(Expr);
15463}
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static std::optional< int > CompareSCEVComplexity(EquivalenceClasses< const SCEV * > &EqCacheSCEV, EquivalenceClasses< const Value * > &EqCacheValue, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static int CompareValueComplexity(EquivalenceClasses< const Value * > &EqCacheValue, const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2)
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static bool CollectAddOperandsWithScales(DenseMap< const SCEV *, APInt > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
scalar evolution
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file contains some functions that are useful when dealing with strings.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
This defines the Use class.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Virtual Register Rewriter
Definition: VirtRegMap.cpp:237
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition: blake3_impl.h:78
A rewriter to replace SCEV expressions in Map with the corresponding entry in the map.
const SCEV * visitAddRecExpr(const SCEVAddRecExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
SCEVLoopGuardRewriter(ScalarEvolution &SE, DenseMap< const SCEV *, const SCEV * > &M)
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
Class for arbitrary precision integers.
Definition: APInt.h:76
APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition: APInt.cpp:1941
APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1543
APInt zext(unsigned width) const
Zero extend to a new width.
Definition: APInt.cpp:981
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition: APInt.h:401
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1491
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition: APInt.h:1370
APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition: APInt.cpp:608
APInt zextOrTrunc(unsigned width) const
Zero extend or truncate to width.
Definition: APInt.cpp:1002
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1463
APInt trunc(unsigned width) const
Truncate to new width.
Definition: APInt.cpp:906
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition: APInt.h:184
APInt abs() const
Get the absolute value.
Definition: APInt.h:1737
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition: APInt.h:1160
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:358
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition: APInt.h:444
APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition: APInt.cpp:1636
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1439
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition: APInt.h:1089
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:187
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition: APInt.h:194
bool isNegative() const
Determine sign of this APInt.
Definition: APInt.h:307
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition: APInt.h:1144
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition: APInt.h:197
unsigned countTrailingZeros() const
Definition: APInt.h:1597
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition: APInt.h:334
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition: APInt.h:805
APInt multiplicativeInverse() const
Definition: APInt.cpp:1244
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition: APInt.h:1128
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:954
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:851
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition: APInt.h:284
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition: APInt.h:319
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1108
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:178
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:410
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:217
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition: APInt.h:1199
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition: Analysis.h:47
API to communicate dependencies between analyses during invalidation.
Definition: PassManager.h:360
bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Trigger the invalidation of some other analysis pass if not already handled and return whether it was...
Definition: PassManager.h:378
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:321
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:473
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
ArrayRef< T > take_front(size_t N=1) const
Return a copy of *this with only the first N elements.
Definition: ArrayRef.h:228
iterator end() const
Definition: ArrayRef.h:154
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:165
iterator begin() const
Definition: ArrayRef.h:153
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
bool isSingleEdge() const
Check if this is the only edge between Start and End.
Definition: Dominators.cpp:51
LLVM Basic Block Representation.
Definition: BasicBlock.h:60
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:430
const Instruction & front() const
Definition: BasicBlock.h:453
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:452
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:206
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:221
Value * getRHS() const
unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
Value * getLHS() const
BinaryOps getOpcode() const
Definition: InstrTypes.h:513
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition: Allocator.h:148
This class represents a function call, abstracting a target machine's calling convention.
Value handle with callbacks on RAUW and destruction.
Definition: ValueHandle.h:383
void setValPtr(Value *P)
Definition: ValueHandle.h:390
bool isFalseWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1320
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:993
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:1022
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:1023
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:1017
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:1016
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:1020
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:1018
@ ICMP_EQ
equal
Definition: InstrTypes.h:1014
@ ICMP_NE
not equal
Definition: InstrTypes.h:1015
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:1021
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:1019
bool isSigned() const
Definition: InstrTypes.h:1265
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:1167
bool isTrueWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1314
Predicate getNonStrictPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
Definition: InstrTypes.h:1211
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:1129
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:1105
Predicate getFlippedSignednessPredicate()
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert.
Definition: InstrTypes.h:1308
bool isUnsigned() const
Definition: InstrTypes.h:1271
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition: InstrTypes.h:1261
static Constant * getNot(Constant *C)
Definition: Constants.cpp:2529
static Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2112
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, bool InBounds=false, std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition: Constants.h:1200
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2535
static Constant * getNeg(Constant *C, bool HasNSW=false)
Definition: Constants.cpp:2523
static Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2098
This is the shared class of boolean and integer constants.
Definition: Constants.h:80
bool isMinusOne() const
This function will return true iff every bit in this constant is set to true.
Definition: Constants.h:217
bool isOne() const
This is just a convenience method to make client code smaller for a common case.
Definition: Constants.h:211
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:205
static ConstantInt * getFalse(LLVMContext &Context)
Definition: Constants.cpp:856
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:154
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:145
static ConstantInt * getBool(LLVMContext &Context, bool V)
Definition: Constants.cpp:863
This class represents a range of values.
Definition: ConstantRange.h:47
ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
ConstantRange truncate(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other? NOTE: false does not mean that inverse pr...
bool isEmptySet() const
Return true if this set contains no members.
ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
void print(raw_ostream &OS) const
Print out the bounds to a stream.
ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
bool contains(const APInt &Val) const
Return true if the specified value is in the set.
APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
Definition: ConstantRange.h:84
static ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition: Constant.h:41
bool isNullValue() const
Return true if this is the value that would be returned by getNullValue.
Definition: Constants.cpp:90
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
Definition: DataLayout.cpp:720
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
Definition: DataLayout.cpp:878
unsigned getIndexTypeSizeInBits(Type *Ty) const
Layout size of the index used in GEP calculation.
Definition: DataLayout.cpp:774
IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
Definition: DataLayout.cpp:905
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:672
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:202
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:155
bool erase(const KeyT &Val)
Definition: DenseMap.h:329
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition: DenseMap.h:71
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition: DenseMap.h:180
bool empty() const
Definition: DenseMap.h:98
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:151
iterator end()
Definition: DenseMap.h:84
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:145
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:220
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
bool properlyDominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
properlyDominates - Returns true iff A dominates B and A != B.
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
EquivalenceClasses - This represents a collection of equivalence classes and supports three efficient...
member_iterator unionSets(const ElemTy &V1, const ElemTy &V2)
union - Merge the two equivalence sets for the specified values, inserting them if they do not alread...
bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition: FoldingSet.h:289
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition: FoldingSet.h:319
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
const BasicBlock & getEntryBlock() const
Definition: Function.h:787
static Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:656
static bool isPrivateLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:406
static bool isInternalLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:403
This instruction compares its operands according to the predicate given to the constructor.
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
static bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
bool isEquality() const
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
const BasicBlock * getParent() const
Definition: Instruction.h:152
Class to represent integer types.
Definition: DerivedTypes.h:40
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:278
An instruction for reading from memory.
Definition: Instructions.h:184
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
iterator end() const
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
iterator begin() const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:593
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:44
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:60
Metadata node.
Definition: Metadata.h:1067
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Function * getFunction(StringRef Name) const
Look up the specified function in the module symbol table.
Definition: Module.cpp:193
This is a utility class that provides an abstraction for the common functionality between Instruction...
Definition: Operator.h:31
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition: Operator.h:41
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition: Operator.h:76
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition: Operator.h:109
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition: Operator.h:103
iterator_range< const_block_iterator > blocks() const
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:662
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1827
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
const SCEVPredicate & getPredicate() const
bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:109
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:115
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition: Analysis.h:264
constexpr bool isValid() const
Definition: Register.h:116
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
bool implies(const SCEVPredicate *N) const override
Implementation of the SCEVPredicate interface.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
const SCEV *const * Operands
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
virtual bool implies(const SCEVPredicate *N) const =0
Returns true if this predicate implies N.
SCEVPredicate(const SCEVPredicate &)=default
virtual void print(raw_ostream &OS, unsigned Depth=0) const =0
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed maximum selection.
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
This class represents a sequential/in-order unsigned minimum selection.
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
void visitAll(const SCEV *Root)
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
const SCEV * getLHS() const
const SCEV * getRHS() const
This class represents an unsigned maximum selection.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds)
Union predicates don't get cached so create a dummy set ID for it.
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
bool isOne() const
Return true if the expression is a constant one.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
Type * getWiderType(Type *Ty1, Type *Ty2) const
const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?...
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
unsigned getSmallConstantMaxTripCount(const Loop *L)
Returns the upper bound of the loop trip count as a normal unsigned value.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
bool isKnownViaInduction(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
std::optional< bool > evaluatePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
bool isKnownPredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
bool isKnownOnEveryIteration(ICmpInst::Predicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
std::optional< bool > evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
void print(raw_ostream &OS) const
const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVector< const SCEVPredicate *, 4 > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
void forgetTopmostLoop(const Loop *L)
void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may effect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallPtrSetImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
const SCEV * getVScale(Type *Ty)
unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
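A hedged usage sketch (guardedBTC is a hypothetical helper; getBackedgeTakenCount is assumed to be the standard trip-count query):

  #include "llvm/Analysis/ScalarEvolution.h"
  using namespace llvm;

  // Hypothetical helper: backedge-taken count refined by the loop's guards.
  static const SCEV *guardedBTC(ScalarEvolution &SE, const Loop *L) {
    return SE.applyLoopGuards(SE.getBackedgeTakenCount(L), L);
  }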
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
const SCEV * getUnknown(Value *V)
std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
const SCEV * getElementCount(Type *Ty, ElementCount EC)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
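A brief illustrative fragment showing how the static setFlags, clearFlags, and maskFlags helpers listed above compose (Flags and HasNSW are local names invented for the example):

  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);    // turn NUW on
  Flags = ScalarEvolution::clearFlags(Flags, SCEV::FlagNUW);  // turn NUW off
  bool HasNSW =
      ScalarEvolution::maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW;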
std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and return the result as an APInt if it is a constant, and std::nullopt if it isn'...
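A small sketch of a typical use (provablyEqual is a hypothetical helper):

  #include "llvm/Analysis/ScalarEvolution.h"
  using namespace llvm;

  // Hypothetical helper: are A and B provably the same value?
  static bool provablyEqual(ScalarEvolution &SE, const SCEV *A, const SCEV *B) {
    std::optional<APInt> Diff = SE.computeConstantDifference(A, B);
    return Diff && Diff->isZero();
  }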
bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
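A minimal sketch of building a sum through this folder (addPlusOne is hypothetical; getConstant is assumed to be the usual constant-SCEV constructor):

  #include "llvm/Analysis/ScalarEvolution.h"
  using namespace llvm;

  // Hypothetical helper: X + Y + 1 built through the canonicalizing folder.
  static const SCEV *addPlusOne(ScalarEvolution &SE, const SCEV *X,
                                const SCEV *Y) {
    SmallVector<const SCEV *, 4> Ops = {X, Y, SE.getConstant(X->getType(), 1)};
    // Ops may be reordered and folded; the result can be simpler than an add.
    return SE.getAddExpr(Ops);
  }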
const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:321
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:342
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:366
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:427
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:135
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
size_type size() const
Definition: SmallSet.h:161
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void reserve(size_type N)
Definition: SmallVector.h:676
iterator erase(const_iterator CI)
Definition: SmallVector.h:750
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:818
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
An instruction for storing to memory.
Definition: Instructions.h:317
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition: DataLayout.h:622
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:651
TypeSize getSizeInBits() const
Definition: DataLayout.h:631
Class to represent struct types.
Definition: DerivedTypes.h:216
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:255
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:243
static IntegerType * getInt32Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:228
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:242
Use & Op()
Definition: User.h:133
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:532
void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5079
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1074
iterator_range< use_iterator > uses()
Definition: Value.h:376
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:171
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2178
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2183
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2188
std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition: APInt.cpp:2781
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2193
APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:767
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:1029
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:822
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:299
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
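A short sketch of how the matcher combinators above compose (isSelectOfValues is a hypothetical helper; V is the value being inspected):

  #include "llvm/IR/PatternMatch.h"
  using namespace llvm;
  using namespace llvm::PatternMatch;

  // Recognize a select and capture its three operands.
  static bool isSelectOfValues(Value *V, Value *&Cond, Value *&TV, Value *&FV) {
    // m_Value binds each matched operand to the supplied reference.
    return match(V, m_Select(m_Value(Cond), m_Value(TV), m_Value(FV)));
  }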
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:189
@ ReallyHidden
Definition: CommandLine.h:139
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:450
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:470
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:31
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:456
void stable_sort(R &&Range)
Definition: STLExtras.h:1995
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1722
bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:7060
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2073
Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
bool VerifySCEV
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0s from the least significant bit to the most significant bit, stopping at the first 1.
Definition: bit.h:215
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO's result is used only along the paths control dependen...
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
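A hedged sketch of using this matcher on a loop-header PHI (isSimpleIV is a hypothetical helper):

  #include "llvm/Analysis/ValueTracking.h"
  #include "llvm/IR/Instructions.h"
  using namespace llvm;

  // Hypothetical helper: does PN head a first-order recurrence cycle?
  static bool isSimpleIV(const PHINode *PN) {
    BinaryOperator *BO = nullptr;
    Value *Start = nullptr, *Step = nullptr;
    // Succeeds for cycles of the form: iv = phi [Start, ...], [iv op Step, ...].
    return matchSimpleRecurrence(PN, BO, Start, Step);
  }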
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2059
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1729
bool getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL, const TargetLibraryInfo *TLI, ObjectSizeOpts Opts={})
Compute the size of the object pointed by Ptr.
void initializeScalarEvolutionWrapperPassPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:419
bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1118
bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
Constant * ConstantFoldInstOperands(Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1108
bool programUndefinedIfPoison(const Instruction *Inst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:116
ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:244
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1914
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
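A minimal sketch of querying known bits for an integer value (V and DL are illustrative names; the overload returning KnownBits by value is assumed):

  #include "llvm/Analysis/ValueTracking.h"
  #include "llvm/Support/KnownBits.h"
  using namespace llvm;

  // The returned KnownBits exposes sign and range queries (see the entries below).
  static APInt knownUpperBound(const Value *V, const DataLayout &DL) {
    KnownBits Known = computeKnownBits(V, DL);
    return Known.getMaxValue();  // maximal unsigned value consistent with Known
  }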
DWARFExpression::Operation Op
auto max_element(R &&Range)
Definition: STLExtras.h:1986
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:293
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:191
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1849
bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition: STLExtras.h:1921
auto predecessors(const MachineBasicBlock *BB)
bool isAllocationFn(const Value *V, const TargetLibraryInfo *TLI)
Tests if a value is a call or invoke to a library function that allocates or reallocates memory (eith...
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1879
unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return the number of times the sign bit of the register is replicated into the other bits.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:26
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:297
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:104
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:434
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:40
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:376
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:192
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:141
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:125
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:101
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:291
Various options to control the behavior of getObjectSize.
bool NullIsUnknownSize
If this is true, null pointers in address space 0 will be treated as though they can't be evaluated.
bool RoundToAlign
Whether to round the result up to the alignment of allocas, byval arguments, and global variables.
An object of this class is returned by queries that could not be answered.
static bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
void addPredicate(const SCEVPredicate *P)