llvm.org GIT mirror llvm / 48973d7
Make SwitchInstProfUpdateWrapper safer While prof branch_weights inconsistencies are being fixed patch by patch (pass by pass) we need SwitchInstProfUpdateWrapper to be safe with respect to inconsistent metadata that can come from passes that have not been fixed yet. See the bug found by @nikic in https://reviews.llvm.org/D62126. This patch introduces one more state (called Invalid) to the wrapper class that allows users to work with the underlying SwitchInst ignoring the prof metadata changes. Created a unit test for the SwitchInstProfUpdateWrapper class. Reviewers: davidx, nikic, eraman, reames, chandlerc Reviewed By: davidx Differential Revision: https://reviews.llvm.org/D62656 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@362473 91177308-0d34-0410-b5e6-96231b3b80d8 Yevgeny Rouban a month ago
3 changed file(s) with 132 addition(s) and 24 deletion(s). Raw diff Collapse all Expand all
34383438 /// their prof branch_weights metadata.
34393439 class SwitchInstProfUpdateWrapper {
34403440 SwitchInst &SI;
3441 Optional > Weights;
3442 bool Changed = false;
3441 Optional > Weights = None;
3442
3443 // Sticky invalid state is needed to safely ignore operations with prof data
3444 // in cases where SwitchInstProfUpdateWrapper is created from SwitchInst
3445 // with inconsistent prof data. TODO: once we fix all prof data
3446 // inconsistencies we can turn invalid state to assertions.
3447 enum {
3448 Invalid,
3449 Initialized,
3450 Changed
3451 } State = Invalid;
34433452
34443453 protected:
34453454 static MDNode *getProfBranchWeightsMD(const SwitchInst &SI);
34463455
34473456 MDNode *buildProfBranchWeightsMD();
34483457
3449 Optional > getProfBranchWeights();
3458 void init();
34503459
34513460 public:
34523461 using CaseWeightOpt = Optional;
34543463 SwitchInst &operator*() { return SI; }
34553464 operator SwitchInst *() { return &SI; }
34563465
3457 SwitchInstProfUpdateWrapper(SwitchInst &SI)
3458 : SI(SI), Weights(getProfBranchWeights()) {}
3466 SwitchInstProfUpdateWrapper(SwitchInst &SI) : SI(SI) { init(); }
34593467
34603468 ~SwitchInstProfUpdateWrapper() {
3461 if (Changed)
3469 if (State == Changed)
34623470 SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD());
34633471 }
34643472
4343 #include
4444
4545 using namespace llvm;
46
47 static cl::opt SwitchInstProfUpdateWrapperStrict(
48 "switch-inst-prof-update-wrapper-strict", cl::Hidden,
49 cl::desc("Assert that prof branch_weights metadata is valid when creating "
50 "an instance of SwitchInstProfUpdateWrapper"),
51 cl::init(false));
4652
4753 //===----------------------------------------------------------------------===//
4854 // AllocaInst Class
38793885 }
38803886
38813887 MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
3882 assert(Changed && "called only if metadata has changed");
3888 assert(State == Changed && "called only if metadata has changed");
38833889
38843890 if (!Weights)
38853891 return nullptr;
38963902 return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights);
38973903 }
38983904
3899 Optional >
3900 SwitchInstProfUpdateWrapper::getProfBranchWeights() {
3905 void SwitchInstProfUpdateWrapper::init() {
39013906 MDNode *ProfileData = getProfBranchWeightsMD(SI);
3902 if (!ProfileData)
3903 return None;
3907 if (!ProfileData) {
3908 State = Initialized;
3909 return;
3910 }
3911
3912 if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) {
3913 State = Invalid;
3914 if (SwitchInstProfUpdateWrapperStrict)
3915 assert(!"number of prof branch_weights metadata operands corresponds to"
3916 " number of succesors");
3917 return;
3918 }
39043919
39053920 SmallVector Weights;
39063921 for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) {
39083923 uint32_t CW = C->getValue().getZExtValue();
39093924 Weights.push_back(CW);
39103925 }
3911 return Weights;
3926 State = Initialized;
3927 this->Weights = std::move(Weights);
39123928 }
39133929
39143930 SwitchInst::CaseIt
39163932 if (Weights) {
39173933 assert(SI.getNumSuccessors() == Weights->size() &&
39183934 "num of prof branch_weights must accord with num of successors");
3919 Changed = true;
3935 State = Changed;
39203936 // Copy the last case to the place of the removed one and shrink.
39213937 // This is tightly coupled with the way SwitchInst::removeCase() removes
39223938 // the cases in SwitchInst::removeCase(CaseIt).
39313947 SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
39323948 SI.addCase(OnVal, Dest);
39333949
3950 if (State == Invalid)
3951 return;
3952
39343953 if (!Weights && W && *W) {
3935 Changed = true;
3954 State = Changed;
39363955 Weights = SmallVector(SI.getNumSuccessors(), 0);
39373956 Weights.getValue()[SI.getNumSuccessors() - 1] = *W;
39383957 } else if (Weights) {
3939 Changed = true;
3958 State = Changed;
39403959 Weights.getValue().push_back(W ? *W : 0);
39413960 }
39423961 if (Weights)
39473966 SymbolTableList::iterator
39483967 SwitchInstProfUpdateWrapper::eraseFromParent() {
39493968 // Instruction is erased. Mark as unchanged to not touch it in the destructor.
3950 Changed = false;
3951
3952 if (Weights)
3953 Weights->resize(0);
3969 if (State != Invalid) {
3970 State = Initialized;
3971 if (Weights)
3972 Weights->resize(0);
3973 }
39543974 return SI.eraseFromParent();
39553975 }
39563976
39633983
39643984 void SwitchInstProfUpdateWrapper::setSuccessorWeight(
39653985 unsigned idx, SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
3966 if (!W)
3986 if (!W || State == Invalid)
39673987 return;
39683988
39693989 if (!Weights && *W)
39723992 if (Weights) {
39733993 auto &OldW = Weights.getValue()[idx];
39743994 if (*W != OldW) {
3975 Changed = true;
3995 State = Changed;
39763996 OldW = *W;
39773997 }
39783998 }
39824002 SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI,
39834003 unsigned idx) {
39844004 if (MDNode *ProfileData = getProfBranchWeightsMD(SI))
3985 return mdconst::extract(ProfileData->getOperand(idx + 1))
3986 ->getValue()
3987 .getZExtValue();
4005 if (ProfileData->getNumOperands() == SI.getNumSuccessors() + 1)
4006 return mdconst::extract(ProfileData->getOperand(idx + 1))
4007 ->getValue()
4008 .getZExtValue();
39884009
39894010 return None;
39904011 }
750750 const auto &Handle = *CCI;
751751 EXPECT_EQ(1, Handle.getCaseValue()->getSExtValue());
752752 EXPECT_EQ(BB1.get(), Handle.getCaseSuccessor());
753 }
754
755 TEST(InstructionsTest, SwitchInstProfUpdateWrapper) {
756 LLVMContext C;
757
758 std::unique_ptr BB1, BB2, BB3;
759 BB1.reset(BasicBlock::Create(C));
760 BB2.reset(BasicBlock::Create(C));
761 BB3.reset(BasicBlock::Create(C));
762
763 // We create block 0 after the others so that it gets destroyed first and
764 // clears the uses of the other basic blocks.
765 std::unique_ptr BB0(BasicBlock::Create(C));
766
767 auto *Int32Ty = Type::getInt32Ty(C);
768
769 SwitchInst *SI =
770 SwitchInst::Create(UndefValue::get(Int32Ty), BB0.get(), 4, BB0.get());
771 SI->addCase(ConstantInt::get(Int32Ty, 1), BB1.get());
772 SI->addCase(ConstantInt::get(Int32Ty, 2), BB2.get());
773 SI->setMetadata(LLVMContext::MD_prof,
774 MDBuilder(C).createBranchWeights({ 9, 1, 22 }));
775
776 {
777 SwitchInstProfUpdateWrapper SIW(*SI);
778 EXPECT_EQ(*SIW.getSuccessorWeight(0), 9u);
779 EXPECT_EQ(*SIW.getSuccessorWeight(1), 1u);
780 EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
781 SIW.setSuccessorWeight(0, 99u);
782 SIW.setSuccessorWeight(1, 11u);
783 EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
784 EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
785 EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
786 }
787
788 { // Create another wrapper and check that the data persist.
789 SwitchInstProfUpdateWrapper SIW(*SI);
790 EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
791 EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
792 EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
793 }
794
795 // Make prof data invalid by adding one extra weight.
796 SI->setMetadata(LLVMContext::MD_prof, MDBuilder(C).createBranchWeights(
797 { 99, 11, 22, 33 })); // extra
798 { // Invalid prof data makes wrapper act as if there were no prof data.
799 SwitchInstProfUpdateWrapper SIW(*SI);
800 ASSERT_FALSE(SIW.getSuccessorWeight(0).hasValue());
801 ASSERT_FALSE(SIW.getSuccessorWeight(1).hasValue());
802 ASSERT_FALSE(SIW.getSuccessorWeight(2).hasValue());
803 SIW.addCase(ConstantInt::get(Int32Ty, 3), BB3.get(), 39);
804 ASSERT_FALSE(SIW.getSuccessorWeight(3).hasValue()); // did not add weight 39
805 }
806
807 { // With added 3rd case the prof data become consistent with num of cases.
808 SwitchInstProfUpdateWrapper SIW(*SI);
809 EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
810 EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
811 EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
812 EXPECT_EQ(*SIW.getSuccessorWeight(3), 33u);
813 }
814
815 // Make prof data invalid by removing one extra weight.
816 SI->setMetadata(LLVMContext::MD_prof,
817 MDBuilder(C).createBranchWeights({ 99, 11, 22 })); // shorter
818 { // Invalid prof data makes wrapper act as if there were no prof data.
819 SwitchInstProfUpdateWrapper SIW(*SI);
820 ASSERT_FALSE(SIW.getSuccessorWeight(0).hasValue());
821 ASSERT_FALSE(SIW.getSuccessorWeight(1).hasValue());
822 ASSERT_FALSE(SIW.getSuccessorWeight(2).hasValue());
823 SIW.removeCase(SwitchInst::CaseIt(SI, 2));
824 }
825
826 { // With removed 3rd case the prof data become consistent with num of cases.
827 SwitchInstProfUpdateWrapper SIW(*SI);
828 EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
829 EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
830 EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
831 }
753832 }
754833
755834 TEST(InstructionsTest, CommuteShuffleMask) {