diff --git a/internal/app/poll/poller.go b/internal/app/poll/poller.go index 52ca8432..20d88e1f 100644 --- a/internal/app/poll/poller.go +++ b/internal/app/poll/poller.go @@ -12,7 +12,6 @@ import ( "sync/atomic" "time" - "gitea.com/gitea/act_runner/internal/app/run" "gitea.com/gitea/act_runner/internal/pkg/client" "gitea.com/gitea/act_runner/internal/pkg/config" "gitea.com/gitea/act_runner/internal/pkg/metrics" @@ -22,9 +21,15 @@ import ( log "github.com/sirupsen/logrus" ) +// TaskRunner abstracts task execution so the poller can be tested +// without a real runner. +type TaskRunner interface { + Run(ctx context.Context, task *runnerv1.Task) error +} + type Poller struct { client client.Client - runner *run.Runner + runner TaskRunner cfg *config.Config tasksVersion atomic.Int64 // tasksVersion used to store the version of the last task fetched from the Gitea. @@ -37,20 +42,19 @@ type Poller struct { done chan struct{} } -// workerState holds per-goroutine polling state. Backoff counters are -// per-worker so that with Capacity > 1, N workers each seeing one empty -// response don't combine into a "consecutive N empty" reading on a shared -// counter and trigger an unnecessarily long backoff. +// workerState holds the single poller's backoff state. Consecutive empty or +// error responses drive exponential backoff; a successful task fetch resets +// both counters so the next poll fires immediately. type workerState struct { consecutiveEmpty int64 consecutiveErrors int64 - // lastBackoff is the last interval reported to the PollBackoffSeconds gauge - // from this worker; used to suppress redundant no-op Set calls when the - // backoff plateaus (e.g. at FetchIntervalMax). + // lastBackoff is the last interval reported to the PollBackoffSeconds gauge; + // used to suppress redundant no-op Set calls when the backoff plateaus + // (e.g. at FetchIntervalMax). lastBackoff time.Duration } -func New(cfg *config.Config, client client.Client, runner *run.Runner) *Poller { +func New(cfg *config.Config, client client.Client, runner TaskRunner) *Poller { pollingCtx, shutdownPolling := context.WithCancel(context.Background()) jobsCtx, shutdownJobs := context.WithCancel(context.Background()) @@ -73,22 +77,57 @@ func New(cfg *config.Config, client client.Client, runner *run.Runner) *Poller { } func (p *Poller) Poll() { + sem := make(chan struct{}, p.cfg.Runner.Capacity) wg := &sync.WaitGroup{} - for i := 0; i < p.cfg.Runner.Capacity; i++ { - wg.Add(1) - go p.poll(wg) - } - wg.Wait() + s := &workerState{} - // signal that we shutdown - close(p.done) + defer func() { + wg.Wait() + close(p.done) + }() + + for { + select { + case sem <- struct{}{}: + case <-p.pollingCtx.Done(): + return + } + + task, ok := p.fetchTask(p.pollingCtx, s) + if !ok { + <-sem + if !p.waitBackoff(s) { + return + } + continue + } + + s.resetBackoff() + + wg.Add(1) + go func(t *runnerv1.Task) { + defer wg.Done() + defer func() { <-sem }() + p.runTaskWithRecover(p.jobsCtx, t) + }(task) + } } func (p *Poller) PollOnce() { - p.pollOnce(&workerState{}) - - // signal that we're done - close(p.done) + defer close(p.done) + s := &workerState{} + for { + task, ok := p.fetchTask(p.pollingCtx, s) + if !ok { + if !p.waitBackoff(s) { + return + } + continue + } + s.resetBackoff() + p.runTaskWithRecover(p.jobsCtx, task) + return + } } func (p *Poller) Shutdown(ctx context.Context) error { @@ -101,13 +140,13 @@ func (p *Poller) Shutdown(ctx context.Context) error { // our timeout for shutting down ran out case <-ctx.Done(): - // when both the timeout fires and the graceful shutdown - // completed succsfully, this branch of the select may - // fire. Do a non-blocking check here against the graceful - // shutdown status to avoid sending an error if we don't need to. - _, ok := <-p.done - if !ok { + // Both the timeout and the graceful shutdown may fire + // simultaneously. Do a non-blocking check to avoid forcing + // a shutdown when graceful already completed. + select { + case <-p.done: return nil + default: } // force a shutdown of all running jobs @@ -120,18 +159,27 @@ func (p *Poller) Shutdown(ctx context.Context) error { } } -func (p *Poller) poll(wg *sync.WaitGroup) { - defer wg.Done() - s := &workerState{} - for { - p.pollOnce(s) +func (s *workerState) resetBackoff() { + s.consecutiveEmpty = 0 + s.consecutiveErrors = 0 + s.lastBackoff = 0 +} - select { - case <-p.pollingCtx.Done(): - return - default: - continue - } +// waitBackoff sleeps for the current backoff interval (with jitter). +// Returns false if the polling context was cancelled during the wait. +func (p *Poller) waitBackoff(s *workerState) bool { + base := p.calculateInterval(s) + if base != s.lastBackoff { + metrics.PollBackoffSeconds.Set(base.Seconds()) + s.lastBackoff = base + } + timer := time.NewTimer(addJitter(base)) + select { + case <-timer.C: + return true + case <-p.pollingCtx.Done(): + timer.Stop() + return false } } @@ -167,34 +215,6 @@ func addJitter(d time.Duration) time.Duration { return d + time.Duration(jitter) } -func (p *Poller) pollOnce(s *workerState) { - for { - task, ok := p.fetchTask(p.pollingCtx, s) - if !ok { - base := p.calculateInterval(s) - if base != s.lastBackoff { - metrics.PollBackoffSeconds.Set(base.Seconds()) - s.lastBackoff = base - } - timer := time.NewTimer(addJitter(base)) - select { - case <-timer.C: - case <-p.pollingCtx.Done(): - timer.Stop() - return - } - continue - } - - // Got a task — reset backoff counters for fast subsequent polling. - s.consecutiveEmpty = 0 - s.consecutiveErrors = 0 - - p.runTaskWithRecover(p.jobsCtx, task) - return - } -} - func (p *Poller) runTaskWithRecover(ctx context.Context, task *runnerv1.Task) { defer func() { if r := recover(); r != nil { diff --git a/internal/app/poll/poller_test.go b/internal/app/poll/poller_test.go index 06060c65..6a69da49 100644 --- a/internal/app/poll/poller_test.go +++ b/internal/app/poll/poller_test.go @@ -6,6 +6,8 @@ package poll import ( "context" "errors" + "sync" + "sync/atomic" "testing" "time" @@ -19,11 +21,10 @@ import ( "github.com/stretchr/testify/require" ) -// TestPoller_PerWorkerCounters verifies that each worker maintains its own -// backoff counters. With a shared counter, N workers each seeing one empty -// response would inflate the counter to N and trigger an unnecessarily long -// backoff. With per-worker state, each worker only sees its own count. -func TestPoller_PerWorkerCounters(t *testing.T) { +// TestPoller_WorkerStateCounters verifies that workerState correctly tracks +// consecutive empty responses independently per state instance, and that +// fetchTask increments only the relevant counter. +func TestPoller_WorkerStateCounters(t *testing.T) { client := mocks.NewClient(t) client.On("FetchTask", mock.Anything, mock.Anything).Return( func(_ context.Context, _ *connect_go.Request[runnerv1.FetchTaskRequest]) (*connect_go.Response[runnerv1.FetchTaskResponse], error) { @@ -77,8 +78,8 @@ func TestPoller_FetchErrorIncrementsErrorsOnly(t *testing.T) { assert.Equal(t, int64(0), s.consecutiveEmpty) } -// TestPoller_CalculateInterval verifies the per-worker exponential backoff -// math is correctly driven by the worker's own counters. +// TestPoller_CalculateInterval verifies the exponential backoff math is +// correctly driven by the workerState counters. func TestPoller_CalculateInterval(t *testing.T) { cfg, err := config.LoadDefault("") require.NoError(t, err) @@ -106,3 +107,154 @@ func TestPoller_CalculateInterval(t *testing.T) { }) } } + +// atomicMax atomically updates target to max(target, val). +func atomicMax(target *atomic.Int64, val int64) { + for { + old := target.Load() + if val <= old || target.CompareAndSwap(old, val) { + break + } + } +} + +type mockRunner struct { + delay time.Duration + running atomic.Int64 + maxConcurrent atomic.Int64 + totalCompleted atomic.Int64 +} + +func (m *mockRunner) Run(ctx context.Context, _ *runnerv1.Task) error { + atomicMax(&m.maxConcurrent, m.running.Add(1)) + select { + case <-time.After(m.delay): + case <-ctx.Done(): + } + m.running.Add(-1) + m.totalCompleted.Add(1) + return nil +} + +// TestPoller_ConcurrencyLimitedByCapacity verifies that with capacity=3 and +// 6 available tasks, at most 3 tasks run concurrently, and FetchTask is +// never called concurrently (single poller). +func TestPoller_ConcurrencyLimitedByCapacity(t *testing.T) { + const ( + capacity = 3 + totalTasks = 6 + taskDelay = 50 * time.Millisecond + ) + + var ( + tasksReturned atomic.Int64 + fetchConcur atomic.Int64 + maxFetchConcur atomic.Int64 + ) + + cli := mocks.NewClient(t) + cli.On("FetchTask", mock.Anything, mock.Anything).Return( + func(_ context.Context, _ *connect_go.Request[runnerv1.FetchTaskRequest]) (*connect_go.Response[runnerv1.FetchTaskResponse], error) { + atomicMax(&maxFetchConcur, fetchConcur.Add(1)) + defer fetchConcur.Add(-1) + + n := tasksReturned.Add(1) + if n <= totalTasks { + return connect_go.NewResponse(&runnerv1.FetchTaskResponse{ + Task: &runnerv1.Task{Id: n}, + }), nil + } + return connect_go.NewResponse(&runnerv1.FetchTaskResponse{}), nil + }, + ) + + runner := &mockRunner{delay: taskDelay} + + cfg, err := config.LoadDefault("") + require.NoError(t, err) + cfg.Runner.Capacity = capacity + cfg.Runner.FetchInterval = 10 * time.Millisecond + cfg.Runner.FetchIntervalMax = 10 * time.Millisecond + + poller := New(cfg, cli, runner) + + var wg sync.WaitGroup + wg.Go(poller.Poll) + + require.Eventually(t, func() bool { + return runner.totalCompleted.Load() >= totalTasks + }, 2*time.Second, 10*time.Millisecond, "all tasks should complete") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err = poller.Shutdown(ctx) + require.NoError(t, err) + wg.Wait() + + assert.LessOrEqual(t, runner.maxConcurrent.Load(), int64(capacity), + "concurrent running tasks must not exceed capacity") + assert.GreaterOrEqual(t, runner.maxConcurrent.Load(), int64(2), + "with 6 tasks and capacity 3, at least 2 should overlap") + assert.Equal(t, int64(1), maxFetchConcur.Load(), + "FetchTask must never be called concurrently (single poller)") + assert.Equal(t, int64(totalTasks), runner.totalCompleted.Load(), + "all tasks should have been executed") +} + +// TestPoller_ShutdownForcesJobsOnTimeout locks in the fix for a +// pre-existing bug where Shutdown's timeout branch used a blocking +// `<-p.done` receive, leaving p.shutdownJobs() unreachable. With a +// task parked on jobsCtx and a Shutdown deadline shorter than the +// task's natural completion, Shutdown must force-cancel via +// shutdownJobs() and return ctx.Err() promptly — not block until the +// task would have finished on its own. +func TestPoller_ShutdownForcesJobsOnTimeout(t *testing.T) { + var served atomic.Bool + cli := mocks.NewClient(t) + cli.On("FetchTask", mock.Anything, mock.Anything).Return( + func(_ context.Context, _ *connect_go.Request[runnerv1.FetchTaskRequest]) (*connect_go.Response[runnerv1.FetchTaskResponse], error) { + if served.CompareAndSwap(false, true) { + return connect_go.NewResponse(&runnerv1.FetchTaskResponse{ + Task: &runnerv1.Task{Id: 1}, + }), nil + } + return connect_go.NewResponse(&runnerv1.FetchTaskResponse{}), nil + }, + ) + + // delay >> Shutdown timeout: Run only returns when jobsCtx is + // cancelled by shutdownJobs(). + runner := &mockRunner{delay: 30 * time.Second} + + cfg, err := config.LoadDefault("") + require.NoError(t, err) + cfg.Runner.Capacity = 1 + cfg.Runner.FetchInterval = 10 * time.Millisecond + cfg.Runner.FetchIntervalMax = 10 * time.Millisecond + + poller := New(cfg, cli, runner) + + var wg sync.WaitGroup + wg.Go(poller.Poll) + + require.Eventually(t, func() bool { + return runner.running.Load() == 1 + }, time.Second, 10*time.Millisecond, "task should start running") + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + start := time.Now() + err = poller.Shutdown(ctx) + elapsed := time.Since(start) + + require.ErrorIs(t, err, context.DeadlineExceeded) + // With the fix, Shutdown returns shortly after the deadline once + // the forced job unwinds. Without the fix, the blocking <-p.done + // would hang for the full 30s mockRunner delay. + assert.Less(t, elapsed, 5*time.Second, + "Shutdown must not block on the parked task; shutdownJobs() must run on timeout") + + wg.Wait() + assert.Equal(t, int64(1), runner.totalCompleted.Load(), + "the parked task must be cancelled and unwound") +} diff --git a/internal/pkg/metrics/metrics.go b/internal/pkg/metrics/metrics.go index 968ed9b8..5ce08382 100644 --- a/internal/pkg/metrics/metrics.go +++ b/internal/pkg/metrics/metrics.go @@ -89,7 +89,7 @@ var ( Namespace: Namespace, Subsystem: "poll", Name: "backoff_seconds", - Help: "Last observed polling backoff interval. With Capacity > 1, reflects whichever worker wrote last.", + Help: "Last observed polling backoff interval in seconds.", }) JobsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{