24#define DEBUG_TYPE "coro-elide"
26STATISTIC(NumOfCoroElided,
"The # of coroutine get elided.");
44 Lowerer(
Module &M) : LowererBase(M) {}
49 void collectPostSplitCoroIds(
Function *
F);
69 if (ValueTy != IntrTy) {
99 if (
auto *Call = dyn_cast<CallInst>(&
I))
101 !Call->isMustTailCall())
102 Call->setTailCall(
false);
107static std::optional<std::pair<uint64_t, Align>>
119 if (!isa<AllocaInst>(&
I))
127 "coro-elide-info-output-file shouldn't be empty");
133 llvm::errs() <<
"Error opening coro-elide-info-output-file '"
135 return std::make_unique<raw_fd_ostream>(2,
false);
155 for (
auto *CA : CoroAllocs) {
156 CA->replaceAllUsesWith(False);
157 CA->eraseFromParent();
166 auto *Frame =
new AllocaInst(FrameTy,
DL.getAllocaAddrSpace(),
"", InsertPt);
167 Frame->setAlignment(FrameAlign);
169 new BitCastInst(Frame, PointerType::getUnqual(
C),
"vFrame", InsertPt);
171 for (
auto *CB : CoroBegins) {
172 CB->replaceAllUsesWith(FrameVoidPtr);
173 CB->eraseFromParent();
183 const auto &It = DestroyAddr.find(CB);
184 assert(It != DestroyAddr.end());
187 unsigned Limit = 32 * (1 + It->second.size());
195 for (
auto *DA : It->second)
199 for (
auto *U : CB->
users()) {
201 if (isa<CoroFreeInst, CoroSubFnInst, CoroSaveInst>(U))
215 EscapingBBs.
insert(cast<Instruction>(U)->getParent());
218 bool PotentiallyEscaped =
false;
222 if (!Visited.
insert(BB).second)
228 PotentiallyEscaped |= EscapingBBs.
count(BB);
231 if (isa<ReturnInst>(BB->getTerminator()) || PotentiallyEscaped)
246 auto TI = BB->getTerminator();
250 if (isa<SwitchInst>(TI) &&
251 CoroSuspendSwitches.count(cast<SwitchInst>(TI))) {
252 Worklist.
push_back(cast<SwitchInst>(TI)->getSuccessor(1));
253 Worklist.
push_back(cast<SwitchInst>(TI)->getSuccessor(2));
257 }
while (!Worklist.
empty());
267 if (CoroAllocs.empty())
281 auto *TI =
B.getTerminator();
283 if (TI->getNumSuccessors() != 0 || isa<UnreachableInst>(TI))
291 for (
const auto &It : DestroyAddr) {
304 return DT.dominates(DA, TI->getTerminator());
307 !hasEscapePath(It.first, Terminators))
308 ReferencedCoroBegins.
insert(It.first);
314 return ReferencedCoroBegins.
size() == CoroBegins.size();
317void Lowerer::collectPostSplitCoroIds(
Function *
F) {
319 CoroSuspendSwitches.clear();
321 if (
auto *CII = dyn_cast<CoroIdInst>(&
I))
322 if (CII->getInfo().isPostSplit())
324 if (CII->getCoroutine() != CII->getFunction())
325 CoroIds.push_back(CII);
332 if (
auto *CSI = dyn_cast<CoroSuspendInst>(&
I))
333 if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) {
334 SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser());
336 CoroSuspendSwitches.insert(SWI);
350 if (
auto *CB = dyn_cast<CoroBeginInst>(U))
351 CoroBegins.push_back(CB);
352 else if (
auto *CA = dyn_cast<CoroAllocInst>(U))
353 CoroAllocs.push_back(CA);
362 if (
auto *II = dyn_cast<CoroSubFnInst>(U))
363 switch (II->getIndex()) {
365 ResumeAddr.push_back(II);
368 DestroyAddr[CB].push_back(II);
378 assert(Resumers &&
"PostSplit coro.id Info argument must refer to an array"
379 "of coroutine subfunctions");
380 auto *ResumeAddrConstant =
385 bool ShouldElide = shouldElide(CoroId->
getFunction(), DT);
388 if (
auto FrameSizeAndAlign =
392 <<
"' not elided in '"
395 <<
ore::NV(
"frame_size", FrameSizeAndAlign->first) <<
", align="
396 <<
ore::NV(
"align", FrameSizeAndAlign->second.value()) <<
")";
400 <<
"' not elided in '"
402 <<
"' (frame_size=unknown, align=unknown)";
408 for (
auto &It : DestroyAddr)
412 if (
auto FrameSizeAndAlign =
414 elideHeapAllocations(CoroId->
getFunction(), FrameSizeAndAlign->first,
415 FrameSizeAndAlign->second, AA);
430 <<
ore::NV(
"frame_size", FrameSizeAndAlign->first) <<
", align="
431 <<
ore::NV(
"align", FrameSizeAndAlign->second.value()) <<
")";
437 <<
"' not elided in '"
439 <<
"' (frame_size=unknown, align=unknown)";
452 auto &M = *
F.getParent();
458 L.collectPostSplitCoroIds(&
F);
460 if (L.CoroIds.empty())
467 bool Changed =
false;
468 for (
auto *CII : L.CoroIds)
469 Changed |= L.processCoroId(CII, AA, DT, ORE);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static void replaceWithConstant(Constant *Value, SmallVectorImpl< CoroSubFnInst * > &Users)
static Instruction * getFirstNonAllocaInTheEntryBlock(Function *F)
static cl::opt< std::string > CoroElideInfoOutputFilename("coro-elide-info-output-file", cl::value_desc("filename"), cl::desc("File to record the coroutines got elided"), cl::Hidden)
static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA)
static std::optional< std::pair< uint64_t, Align > > getFrameLayout(Function *Resume)
static bool declaresCoroElideIntrinsics(Module &M)
static std::unique_ptr< raw_fd_ostream > getOrCreateLogFile()
static bool operandReferences(CallInst *CI, AllocaInst *Frame, AAResults &AA)
This file defines the DenseMap class.
iv Induction Variable Users
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
A manager for alias analyses.
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
an instruction to allocate memory on the stack
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
LLVM Basic Block Representation.
InstListType::iterator iterator
Instruction iterators...
This class represents a no-op cast from one type to another.
This class represents a function call, abstracting a target machine's calling convention.
ConstantArray - Constant Array Declarations.
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static ConstantInt * getFalse(LLVMContext &Context)
This is an important base class in LLVM.
Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
This class represents the llvm.coro.begin instruction.
This represents the llvm.coro.id instruction.
Function * getCoroutine() const
This class represents the llvm.coro.subfn.addr instruction.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
uint64_t getParamDereferenceableBytes(unsigned ArgNo) const
Extract the number of dereferenceable bytes for a parameter.
MaybeAlign getParamAlign(unsigned ArgNo) const
const BasicBlock * getParent() const
const Function * getFunction() const
Return the function this instruction belongs to.
This is an important class for using LLVM in a threaded context.
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
unsigned getNumCases() const
Return the number of 'cases' in this switch instruction, excluding the default case.
The instances of the Type class are immutable: once they are created, they are never changed.
bool isPointerTy() const
True if this is an instance of PointerType.
static IntegerType * getInt8Ty(LLVMContext &C)
iterator_range< value_op_iterator > operand_values()
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
iterator_range< user_iterator > users()
StringRef getName() const
Return a constant reference to the value's name.
self_iterator getIterator()
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
bool declaresIntrinsics(const Module &M, const std::initializer_list< StringRef >)
void replaceCoroFree(CoroIdInst *CoroId, bool Elide)
DiagnosticInfoOptimizationBase::Argument NV
@ OF_Append
The file should be opened in append mode.
This is an optimization pass for GlobalISel generic memory operations.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
bool replaceAndRecursivelySimplify(Instruction *I, Value *SimpleV, const TargetLibraryInfo *TLI=nullptr, const DominatorTree *DT=nullptr, AssumptionCache *AC=nullptr, SmallSetVector< Instruction *, 8 > *UnsimplifiedUsers=nullptr)
Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
RNSuccIterator< NodeRef, BlockT, RegionT > succ_begin(NodeRef Node)
RNSuccIterator< NodeRef, BlockT, RegionT > succ_end(NodeRef Node)
This struct is a compact representation of a valid (non-zero power of two) alignment.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Used in the streaming interface as the general argument type.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.