diff options
author | vporpo <vporpodas@google.com> | 2025-01-25 08:19:27 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-01-25 08:19:27 -0800 |
commit | 5cb2db3b51c2a9d516d57bd2f07d9899bd5fdae7 (patch) | |
tree | f18cd3c2b988e181c27dc3947bdfd5869382c0cd | |
parent | 21f04b1458c52ba875a23b58b02cf6b1f8db0661 (diff) | |
download | llvm-5cb2db3b51c2a9d516d57bd2f07d9899bd5fdae7.zip llvm-5cb2db3b51c2a9d516d57bd2f07d9899bd5fdae7.tar.gz llvm-5cb2db3b51c2a9d516d57bd2f07d9899bd5fdae7.tar.bz2 |
[SandboxVec][Scheduler] Forbid crossing BBs (#124369)
This patch updates the scheduler to forbid scheduling across BBs. It
should eventually be able to handle this, but we disable it for now.
3 files changed, 64 insertions, 2 deletions
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h index 25432e1..0da1894 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h @@ -122,6 +122,8 @@ class Scheduler { std::optional<BasicBlock::iterator> ScheduleTopItOpt; // TODO: This is wasting memory in exchange for fast removal using a raw ptr. DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls; + /// The BB that we are currently scheduling. + BasicBlock *ScheduledBB = nullptr; /// \Returns a scheduling bundle containing \p Instrs. SchedBundle *createBundle(ArrayRef<Instruction *> Instrs); @@ -166,8 +168,10 @@ public: DAG.clear(); ReadyList.clear(); ScheduleTopItOpt = std::nullopt; + ScheduledBB = nullptr; assert(Bndls.empty() && DAG.empty() && ReadyList.empty() && - !ScheduleTopItOpt && "Expected empty state!"); + !ScheduleTopItOpt && ScheduledBB == nullptr && + "Expected empty state!"); } #ifndef NDEBUG diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp index 496521b..06c1ef6b 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp @@ -189,7 +189,13 @@ bool Scheduler::trySchedule(ArrayRef<Instruction *> Instrs) { [Instrs](Instruction *I) { return I->getParent() == (*Instrs.begin())->getParent(); }) && - "Instrs not in the same BB!"); + "Instrs not in the same BB, should have been rejected by Legality!"); + if (ScheduledBB == nullptr) + ScheduledBB = Instrs[0]->getParent(); + // We don't support crossing BBs for now. + if (any_of(Instrs, + [this](Instruction *I) { return I->getParent() != ScheduledBB; })) + return false; auto SchedState = getBndlSchedState(Instrs); switch (SchedState) { case BndlSchedState::FullyScheduled: diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp index c5e44a9..5a2b92e 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp @@ -51,6 +51,14 @@ struct SchedulerTest : public testing::Test { } }; +static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F, + StringRef Name) { + for (sandboxir::BasicBlock &BB : *F) + if (BB.getName() == Name) + return &BB; + llvm_unreachable("Expected to find basic block!"); +} + TEST_F(SchedulerTest, SchedBundle) { parseIR(C, R"IR( define void @foo(ptr %ptr, i8 %v0, i8 %v1) { @@ -237,3 +245,47 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) { EXPECT_TRUE(Sched.trySchedule({Add0, Add1})); EXPECT_TRUE(Sched.trySchedule({L0, L1})); } + +TEST_F(SchedulerTest, DontCrossBBs) { + parseIR(C, R"IR( +define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) { +bb0: + %add0 = add i8 %v0, 0 + %add1 = add i8 %v1, 1 + br label %bb1 +bb1: + store i8 %add0, ptr %ptr0 + store i8 %add1, ptr %ptr1 + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB0 = getBasicBlockByName(F, "bb0"); + auto *BB1 = getBasicBlockByName(F, "bb1"); + auto It = BB0->begin(); + auto *Add0 = &*It++; + auto *Add1 = &*It++; + + It = BB1->begin(); + auto *S0 = cast<sandboxir::StoreInst>(&*It++); + auto *S1 = cast<sandboxir::StoreInst>(&*It++); + auto *Ret = cast<sandboxir::ReturnInst>(&*It++); + + { + // Schedule bottom-up + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); + EXPECT_TRUE(Sched.trySchedule({Ret})); + EXPECT_TRUE(Sched.trySchedule({S0, S1})); + // Scheduling across blocks should fail. + EXPECT_FALSE(Sched.trySchedule({Add0, Add1})); + } + { + // Schedule top-down + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); + EXPECT_TRUE(Sched.trySchedule({Add0, Add1})); + // Scheduling across blocks should fail. + EXPECT_FALSE(Sched.trySchedule({S0, S1})); + } +} |