mas_tasks/new_queue.rs

// Copyright 2024 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::{collections::HashMap, sync::Arc};

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use cron::Schedule;
use mas_context::LogContext;
use mas_storage::{
    Clock, RepositoryAccess, RepositoryError,
    queue::{InsertableJob, Job, JobMetadata, Worker},
};
use mas_storage_pg::{DatabaseError, PgRepository};
use opentelemetry::{
    KeyValue,
    metrics::{Counter, Histogram, UpDownCounter},
};
use rand::{Rng, RngCore, distributions::Uniform};
use rand_chacha::ChaChaRng;
use serde::de::DeserializeOwned;
use sqlx::{
    Acquire, Either,
    postgres::{PgAdvisoryLock, PgListener},
};
use thiserror::Error;
use tokio::{task::JoinSet, time::Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument as _, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt as _;
use ulid::Ulid;

use crate::{METER, State};

type JobPayload = serde_json::Value;

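/// Contextual information handed to a job when it runs.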
#[derive(Clone)]
pub struct JobContext {
    pub id: Ulid,
    pub metadata: JobMetadata,
    pub queue_name: String,
    pub attempt: usize,
    pub start: Instant,

    #[expect(
        dead_code,
        reason = "we're not yet using this, but will be in the future"
    )]
    pub cancellation_token: CancellationToken,
}

impl JobContext {
    pub fn span(&self) -> Span {
        let span = tracing::info_span!(
            parent: Span::none(),
            "job.run",
            job.id = %self.id,
            job.queue.name = self.queue_name,
            job.attempt = self.attempt,
        );

        span.add_link(self.metadata.span_context());

        span
    }
}

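/// What to do with a job after it failed: retry it later or fail it
/// permanently.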
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JobErrorDecision {
    Retry,

    #[default]
    Fail,
}

impl std::fmt::Display for JobErrorDecision {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Retry => f.write_str("retry"),
            Self::Fail => f.write_str("fail"),
        }
    }
}

#[derive(Debug, Error)]
#[error("Job failed to run, will {decision}")]
pub struct JobError {
    decision: JobErrorDecision,
    #[source]
    error: anyhow::Error,
}

impl JobError {
    pub fn retry<T: Into<anyhow::Error>>(error: T) -> Self {
        Self {
            decision: JobErrorDecision::Retry,
            error: error.into(),
        }
    }

    pub fn fail<T: Into<anyhow::Error>>(error: T) -> Self {
        Self {
            decision: JobErrorDecision::Fail,
            error: error.into(),
        }
    }
}

pub trait FromJob {
    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
    where
        Self: Sized;
}

impl<T> FromJob for T
where
    T: DeserializeOwned,
{
    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error> {
        serde_json::from_value(payload).map_err(Into::into)
    }
}

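/// A job which can be picked up and run by the queue worker.
///
/// Handlers report their outcome through [`JobError`]: [`JobError::retry`]
/// for transient failures and [`JobError::fail`] for permanent ones.
///
/// A minimal sketch of a handler, for illustration only
/// (`CleanupExpiredTokensJob` and `cleanup_expired_tokens` are hypothetical
/// names, not items of this crate):
///
/// ```ignore
/// #[derive(serde::Deserialize)]
/// struct CleanupExpiredTokensJob;
///
/// #[async_trait]
/// impl RunnableJob for CleanupExpiredTokensJob {
///     async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
///         // A database hiccup is transient, so ask the queue to retry the job
///         cleanup_expired_tokens(state).await.map_err(JobError::retry)
///     }
/// }
/// ```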
#[async_trait]
pub trait RunnableJob: FromJob + Send + 'static {
    async fn run(&self, state: &State, context: JobContext) -> Result<(), JobError>;
}

fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
    Box::new(job)
}

#[derive(Debug, Error)]
pub enum QueueRunnerError {
    #[error("Failed to setup listener")]
    SetupListener(#[source] sqlx::Error),

    #[error("Failed to start transaction")]
    StartTransaction(#[source] sqlx::Error),

    #[error("Failed to commit transaction")]
    CommitTransaction(#[source] sqlx::Error),

    #[error("Failed to acquire leader lock")]
    LeaderLock(#[source] sqlx::Error),

    #[error(transparent)]
    Repository(#[from] RepositoryError),

    #[error(transparent)]
    Database(#[from] DatabaseError),

    #[error("Invalid schedule expression")]
    InvalidSchedule(#[from] cron::error::Error),

    #[error("Worker is not the leader")]
    NotLeader,
}

// When the worker waits for a notification, we still want to wake it up every
// second. Because we don't want all the workers to wake up at the same time, we
// add a random jitter to the sleep duration, so they effectively sleep between
// 0.9 and 1.1 seconds.
const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);

// How many jobs can we run concurrently
const MAX_CONCURRENT_JOBS: usize = 10;

// How many jobs can we fetch at once
const MAX_JOBS_TO_FETCH: usize = 5;

// How many times a job should be retried
const MAX_ATTEMPTS: usize = 10;

/// Returns the delay to wait before retrying a job
///
/// Uses an exponential backoff: 5s, 10s, 20s, 40s, 1m20s, 2m40s, 5m20s,
/// 10m40s, 21m20s, 42m40s
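///
/// For example, `retry_delay(0)` is 5 seconds and `retry_delay(4)` is
/// 1 minute 20 seconds: the delay doubles with each attempt.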
fn retry_delay(attempt: usize) -> Duration {
    let attempt = u32::try_from(attempt).unwrap_or(u32::MAX);
    Duration::milliseconds(2_i64.saturating_pow(attempt) * 5_000)
}

type JobResult = (std::time::Duration, Result<(), JobError>);
type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;

struct ScheduleDefinition {
    schedule_name: &'static str,
    expression: Schedule,
    queue_name: &'static str,
    payload: serde_json::Value,
}

pub struct QueueWorker {
    rng: ChaChaRng,
    clock: Box<dyn Clock + Send>,
    listener: PgListener,
    registration: Worker,
    am_i_leader: bool,
    last_heartbeat: DateTime<Utc>,
    cancellation_token: CancellationToken,
    #[expect(dead_code, reason = "This is used on Drop")]
    cancellation_guard: tokio_util::sync::DropGuard,
    state: State,
    schedules: Vec<ScheduleDefinition>,
    tracker: JobTracker,
    wakeup_reason: Counter<u64>,
    tick_time: Histogram<u64>,
}

impl QueueWorker {
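    /// Create a new worker: open a Postgres listener connection, register the
    /// worker in the database and set up its metrics.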
    #[tracing::instrument(
        name = "worker.init",
        skip_all,
        fields(worker.id)
    )]
    pub async fn new(
        state: State,
        cancellation_token: CancellationToken,
    ) -> Result<Self, QueueRunnerError> {
        let mut rng = state.rng();
        let clock = state.clock();

        let mut listener = PgListener::connect_with(state.pool())
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        // We get notifications of leader stepping down on this channel
        listener
            .listen("queue_leader_stepdown")
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        // We get notifications when a job is available on this channel
        listener
            .listen("queue_available")
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        let txn = listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        let registration = repo.queue_worker().register(&mut rng, &clock).await?;
        tracing::Span::current().record("worker.id", tracing::field::display(registration.id));
        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        tracing::info!(worker.id = %registration.id, "Registered worker");
        let now = clock.now();

        let wakeup_reason = METER
            .u64_counter("job.worker.wakeups")
            .with_description("Counts how many times the worker has been woken up, and for which reason")
            .build();

        // Pre-create the reasons on the counter
        wakeup_reason.add(0, &[KeyValue::new("reason", "sleep")]);
        wakeup_reason.add(0, &[KeyValue::new("reason", "task")]);
        wakeup_reason.add(0, &[KeyValue::new("reason", "notification")]);

        let tick_time = METER
            .u64_histogram("job.worker.tick_duration")
            .with_description(
                "How much time the worker took to tick, including performing leader duties",
            )
            .build();

        // We put a cancellation drop guard in the structure, so that when it gets
        // dropped, we're sure to cancel the token
        let cancellation_guard = cancellation_token.clone().drop_guard();

        Ok(Self {
            rng,
            clock,
            listener,
            registration,
            am_i_leader: false,
            last_heartbeat: now,
            cancellation_token,
            cancellation_guard,
            state,
            schedules: Vec::new(),
            tracker: JobTracker::new(),
            wakeup_reason,
            tick_time,
        })
    }

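    /// Register a job handler for its queue.
    ///
    /// The handler type must implement both [`RunnableJob`] and
    /// [`InsertableJob`]; the latter provides the queue name.
    ///
    /// Illustrative sketch only (`CleanupExpiredTokensJob` is a hypothetical
    /// job type, not part of this crate):
    ///
    /// ```ignore
    /// worker.register_handler::<CleanupExpiredTokensJob>();
    /// ```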
    pub fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
        // There is a potential panic here, which is fine as it's going to be caught
        // within the job task
        let factory = |payload: JobPayload| {
            box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
        };

        self.tracker
            .factories
            .insert(T::QUEUE_NAME, Arc::new(factory));
        self
    }

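    /// Register a job to be scheduled on a recurring cron expression.
    ///
    /// Illustrative sketch only (`CleanupExpiredTokensJob` is a hypothetical
    /// job type, not part of this crate):
    ///
    /// ```ignore
    /// // Every hour, on the hour
    /// let schedule: Schedule = "0 0 * * * *".parse()?;
    /// worker.add_schedule("cleanup-expired-tokens", schedule, CleanupExpiredTokensJob);
    /// ```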
    pub fn add_schedule<T: InsertableJob>(
        &mut self,
        schedule_name: &'static str,
        expression: Schedule,
        job: T,
    ) -> &mut Self {
        let payload = serde_json::to_value(job).expect("failed to serialize job payload");

        self.schedules.push(ScheduleDefinition {
            schedule_name,
            expression,
            queue_name: T::QUEUE_NAME,
            payload,
        });

        self
    }

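    /// Run the worker until its cancellation token is triggered, then shut it
    /// down gracefully. Errors are logged rather than returned.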
    pub async fn run(mut self) {
        if let Err(e) = self.run_inner().await {
            tracing::error!(
                error = &e as &dyn std::error::Error,
                "Failed to run new queue"
            );
        }
    }

    async fn run_inner(&mut self) -> Result<(), QueueRunnerError> {
        self.setup_schedules().await?;

        while !self.cancellation_token.is_cancelled() {
            LogContext::new("worker-run-loop")
                .run(|| self.run_loop())
                .await?;
        }

        self.shutdown().await?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.setup_schedules", skip_all)]
    pub async fn setup_schedules(&mut self) -> Result<(), QueueRunnerError> {
        let schedules: Vec<_> = self.schedules.iter().map(|s| s.schedule_name).collect();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        let mut repo = PgRepository::from_conn(txn);

        // Setup the entries in the queue_schedules table
        repo.queue_schedule().setup(&schedules).await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.run_loop", skip_all)]
    async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
        self.wait_until_wakeup().await?;

        if self.cancellation_token.is_cancelled() {
            return Ok(());
        }

        let start = Instant::now();
        self.tick().await?;

        if self.am_i_leader {
            self.perform_leader_duties().await?;
        }

        let elapsed = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
        self.tick_time.record(elapsed, &[]);

        Ok(())
    }

    #[tracing::instrument(name = "worker.shutdown", skip_all)]
    async fn shutdown(&mut self) -> Result<(), QueueRunnerError> {
        tracing::info!("Shutting down worker");

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        let mut repo = PgRepository::from_conn(txn);

        // Log about any job still running
        match self.tracker.running_jobs() {
            0 => {}
            1 => tracing::warn!("There is one job still running, waiting for it to finish"),
            n => tracing::warn!("There are {n} jobs still running, waiting for them to finish"),
        }

        // TODO: we may want to introduce a timeout here, and abort the tasks if they
        // take too long. It's fine for now, as we don't have long-running
        // tasks, most of them are idempotent, and the only effect might be that
        // the worker would 'dirtily' shut down, meaning that its tasks would be
        // considered lost and later retried by another worker

        // Wait for all the jobs to finish
        self.tracker
            .process_jobs(&mut self.rng, &self.clock, &mut repo, true)
            .await?;

        // Tell the other workers we're shutting down
        // This also releases the leader election lease
        repo.queue_worker()
            .shutdown(&self.clock, &self.registration)
            .await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.wait_until_wakeup", skip_all)]
    async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> {
        // This is to make sure we wake up every second to do the maintenance tasks
        // We add a little bit of random jitter to the duration, so that we don't get
        // fully synced workers waking up at the same time after each notification
        let sleep_duration = self
            .rng
            .sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
        let wakeup_sleep = tokio::time::sleep(sleep_duration);

        tokio::select! {
            () = self.cancellation_token.cancelled() => {
                tracing::debug!("Woke up from cancellation");
            },

            () = wakeup_sleep => {
                tracing::debug!("Woke up from sleep");
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "sleep")]);
            },

            () = self.tracker.collect_next_job(), if self.tracker.has_jobs() => {
                tracing::debug!("Joined job task");
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "task")]);
            },

            notification = self.listener.recv() => {
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "notification")]);
                match notification {
                    Ok(notification) => {
                        tracing::debug!(
                            notification.channel = notification.channel(),
                            notification.payload = notification.payload(),
                            "Woke up from notification"
                        );
                    },
                    Err(e) => {
                        tracing::error!(error = &e as &dyn std::error::Error, "Failed to receive notification");
                    },
                }
            },
        }

        Ok(())
    }

    #[tracing::instrument(
        name = "worker.tick",
        skip_all,
        fields(worker.id = %self.registration.id),
    )]
    async fn tick(&mut self) -> Result<(), QueueRunnerError> {
        tracing::debug!("Tick");
        let now = self.clock.now();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        // We send a heartbeat every minute, to avoid writing to the database too often
        // on a logged table
        if now - self.last_heartbeat >= chrono::Duration::minutes(1) {
            tracing::info!("Sending heartbeat");
            repo.queue_worker()
                .heartbeat(&self.clock, &self.registration)
                .await?;
            self.last_heartbeat = now;
        }

        // Remove any dead worker leader leases
        repo.queue_worker()
            .remove_leader_lease_if_expired(&self.clock)
            .await?;

        // Try to become (or stay) the leader
        let leader = repo
            .queue_worker()
            .try_get_leader_lease(&self.clock, &self.registration)
            .await?;

        // Process any job task which finished
        self.tracker
            .process_jobs(&mut self.rng, &self.clock, &mut repo, false)
            .await?;

        // Compute how many jobs we should fetch at most: the number of free
        // slots in the internal queue, capped at MAX_JOBS_TO_FETCH per tick
        let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
            .saturating_sub(self.tracker.running_jobs())
            .min(MAX_JOBS_TO_FETCH);

        if max_jobs_to_fetch == 0 {
            tracing::warn!("Internal job queue is full, not fetching any new jobs");
        } else {
            // Grab a few jobs in the queue
            let queues = self.tracker.queues();
            let jobs = repo
                .queue_job()
                .reserve(&self.clock, &self.registration, &queues, max_jobs_to_fetch)
                .await?;

            for Job {
                id,
                queue_name,
                payload,
                metadata,
                attempt,
            } in jobs
            {
                let cancellation_token = self.cancellation_token.child_token();
                let start = Instant::now();
                let context = JobContext {
                    id,
                    metadata,
                    queue_name,
                    attempt,
                    start,
                    cancellation_token,
                };

                self.tracker.spawn_job(self.state.clone(), context, payload);
            }
        }

        // After this point, we are locking the leader table, so it's important that we
        // commit as soon as possible to not block the other workers for too long
        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        // Save the new leader state to log any change
        if leader != self.am_i_leader {
            // If we flipped state, log it
            self.am_i_leader = leader;
            if self.am_i_leader {
                tracing::info!("I'm the leader now");
            } else {
                tracing::warn!("I am no longer the leader");
            }
        }

        Ok(())
    }

    #[tracing::instrument(name = "worker.perform_leader_duties", skip_all)]
    async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> {
        // This should have been checked by the caller, but better safe than sorry
        if !self.am_i_leader {
            return Err(QueueRunnerError::NotLeader);
        }

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        // The thing with the leader election is that it locks the table during the
        // election, preventing other workers from going through the loop.
        //
        // Ideally, we would do the leader duties in the same transaction so that we
        // make sure only one worker is doing the leader duties, but that
        // would mean we would lock all the workers for the duration of the
        // duties, which is not ideal.
        //
        // So we do the duties in a separate transaction, in which we take an advisory
        // lock, so that in the very rare case where two workers think they are the
        // leader, we still don't have two workers doing the duties at the same time.
        let lock = PgAdvisoryLock::new("leader-duties");

        let locked = lock
            .try_acquire(txn)
            .await
            .map_err(QueueRunnerError::LeaderLock)?;

        let locked = match locked {
            Either::Left(locked) => locked,
            Either::Right(txn) => {
                tracing::error!("Another worker has the leader lock, aborting");
                txn.rollback()
                    .await
                    .map_err(QueueRunnerError::CommitTransaction)?;
                return Ok(());
            }
        };

        let mut repo = PgRepository::from_conn(locked);

        // Look at the state of schedules in the database
        let schedules_status = repo.queue_schedule().list().await?;

        let now = self.clock.now();
        for schedule in &self.schedules {
            // Find the schedule status from the database
            let Some(schedule_status) = schedules_status
                .iter()
                .find(|s| s.schedule_name == schedule.schedule_name)
            else {
                tracing::error!(
                    "Schedule {} was not found in the database",
                    schedule.schedule_name
                );
                continue;
            };

            // Figure out if we should schedule a new job
            if let Some(next_time) = schedule_status.last_scheduled_at {
                if next_time > now {
                    // We already have a job scheduled in the future, skip
                    continue;
                }

                if schedule_status.last_scheduled_job_completed == Some(false) {
                    // The last scheduled job has not completed yet, skip
                    continue;
                }
            }

            let next_tick = schedule.expression.after(&now).next().unwrap();

            tracing::info!(
                "Scheduling job for {}, next run at {}",
                schedule.schedule_name,
                next_tick
            );

            repo.queue_job()
                .schedule_later(
                    &mut self.rng,
                    &self.clock,
                    schedule.queue_name,
                    schedule.payload.clone(),
                    serde_json::json!({}),
                    next_tick,
                    Some(schedule.schedule_name),
                )
                .await?;
        }

        // We also shut down all the dead workers that haven't checked in
        // during the last two minutes
        repo.queue_worker()
            .shutdown_dead_workers(&self.clock, Duration::minutes(2))
            .await?;

        // TODO: mark tasks those workers had as lost

        // Mark all the scheduled jobs as available
        let scheduled = repo
            .queue_job()
            .schedule_available_jobs(&self.clock)
            .await?;
        match scheduled {
            0 => {}
            1 => tracing::info!("One scheduled job marked as available"),
            n => tracing::info!("{n} scheduled jobs marked as available"),
        }

        // Release the leader lock
        let txn = repo
            .into_inner()
            .release_now()
            .await
            .map_err(QueueRunnerError::LeaderLock)?;

        txn.commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }
}

/// Tracks running jobs
///
/// This is a separate structure to be able to borrow it mutably at the same
/// time as the connection to the database is borrowed
struct JobTracker {
    /// Stores a mapping from the job queue name to the job factory
    factories: HashMap<&'static str, JobFactory>,

    /// A join set of all the currently running jobs
    running_jobs: JoinSet<JobResult>,

    /// Stores a mapping from the Tokio task ID to the job context
    job_contexts: HashMap<tokio::task::Id, JobContext>,

    /// Stores the last `join_next_with_id` result for processing, in case we
    /// got woken up in `collect_next_job`
    last_join_result: Option<Result<(tokio::task::Id, JobResult), tokio::task::JoinError>>,

    /// A histogram which records the time it takes to process a job
    job_processing_time: Histogram<u64>,

    /// A counter which records the number of jobs currently in flight
    in_flight_jobs: UpDownCounter<i64>,
}

impl JobTracker {
    fn new() -> Self {
        let job_processing_time = METER
            .u64_histogram("job.process.duration")
            .with_description("The time it takes to process a job in milliseconds")
            .with_unit("ms")
            .build();

        let in_flight_jobs = METER
            .i64_up_down_counter("job.active_tasks")
            .with_description("The number of jobs currently in flight")
            .with_unit("{job}")
            .build();

        Self {
            factories: HashMap::new(),
            running_jobs: JoinSet::new(),
            job_contexts: HashMap::new(),
            last_join_result: None,
            job_processing_time,
            in_flight_jobs,
        }
    }

    /// Returns the queue names that are currently being tracked
    fn queues(&self) -> Vec<&'static str> {
        self.factories.keys().copied().collect()
    }

    /// Spawn a job on the job tracker
    fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) {
        let factory = self.factories.get(context.queue_name.as_str()).cloned();
        let task = {
            let log_context = LogContext::new(format!("job-{}", context.queue_name));
            let context = context.clone();
            let span = context.span();
            log_context
                .run(async move || {
                    // We should never crash, but in case we do, we do that in the task and
                    // don't crash the worker
                    let job = factory.expect("unknown job factory")(payload);
                    tracing::info!(
                        job.id = %context.id,
                        job.queue.name = %context.queue_name,
                        job.attempt = %context.attempt,
                        "Running job"
                    );
                    let result = job.run(&state, context.clone()).await;

                    let Some(context_stats) =
                        LogContext::maybe_with(mas_context::LogContext::stats)
                    else {
                        // This should never happen, but if it does it's fine: we're recovering fine
                        // from panics in those tasks
                        panic!("Missing log context, this should never happen");
                    };

                    // We log the result here so that it's attached to the right span & log context
                    match &result {
                        Ok(()) => {
                            tracing::info!(
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job completed [{context_stats}]"
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Fail,
                            error,
                        }) => {
                            tracing::error!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed, not retrying [{context_stats}]"
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Retry,
                            error,
                        }) if context.attempt < MAX_ATTEMPTS => {
                            let delay = retry_delay(context.attempt);
                            tracing::warn!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed, will retry in {}s [{context_stats}]",
                                delay.num_seconds()
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Retry,
                            error,
                        }) => {
                            tracing::error!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed too many times, abandoning [{context_stats}]"
                            );
                        }
                    }

                    (context_stats.elapsed, result)
                })
                .instrument(span)
        };

        self.in_flight_jobs.add(
            1,
            &[KeyValue::new("job.queue.name", context.queue_name.clone())],
        );

        let handle = self.running_jobs.spawn(task);
        self.job_contexts.insert(handle.id(), context);
    }

    /// Returns `true` if there are currently running jobs
    fn has_jobs(&self) -> bool {
        !self.running_jobs.is_empty()
    }

    /// Returns the number of currently running jobs
    ///
    /// This also includes the job result which may be stored for processing
    fn running_jobs(&self) -> usize {
        self.running_jobs.len() + usize::from(self.last_join_result.is_some())
    }

    async fn collect_next_job(&mut self) {
        // Double-check that we don't have a job result stored
        if self.last_join_result.is_some() {
            tracing::error!(
                "Job tracker already had a job result stored, this should never happen!"
            );
            return;
        }

        self.last_join_result = self.running_jobs.join_next_with_id().await;
    }

    /// Process all the jobs which are currently running
    ///
    /// If `blocking` is `true`, this function will block until all the jobs
    /// are finished. Otherwise, it will return as soon as it processed the
    /// already finished jobs.
    #[allow(clippy::too_many_lines)]
    async fn process_jobs<E: std::error::Error + Send + Sync + 'static>(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        repo: &mut dyn RepositoryAccess<Error = E>,
        blocking: bool,
    ) -> Result<(), E> {
        if self.last_join_result.is_none() {
            if blocking {
                self.last_join_result = self.running_jobs.join_next_with_id().await;
            } else {
                self.last_join_result = self.running_jobs.try_join_next_with_id();
            }
        }

        while let Some(result) = self.last_join_result.take() {
            match result {
                // The job succeeded. The logging and time measurement is already done in the task
                Ok((id, (elapsed, Ok(())))) => {
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
                    self.job_processing_time.record(
                        elapsed_ms,
                        &[
                            KeyValue::new("job.queue.name", context.queue_name),
                            KeyValue::new("job.result", "success"),
                        ],
                    );

                    repo.queue_job()
                        .mark_as_completed(clock, context.id)
                        .await?;
                }

                // The job failed. The logging and time measurement is already done in the task
                Ok((id, (elapsed, Err(e)))) => {
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let reason = format!("{:?}", e.error);
                    repo.queue_job()
                        .mark_as_failed(clock, context.id, &reason)
                        .await?;

                    let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
                    match e.decision {
                        JobErrorDecision::Fail => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "fail"),
                                ],
                            );
                        }

                        JobErrorDecision::Retry if context.attempt < MAX_ATTEMPTS => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "retry"),
                                ],
                            );

                            let delay = retry_delay(context.attempt);
                            repo.queue_job()
                                .retry(&mut *rng, clock, context.id, delay)
                                .await?;
                        }

                        JobErrorDecision::Retry => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "abandon"),
                                ],
                            );
                        }
                    }
                }

                // The job crashed (or was aborted)
                Err(e) => {
                    let id = e.id();
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    // This measurement is not accurate as it includes the time processing the jobs,
                    // but it's fine, it's only for panicked tasks
                    let elapsed = context
                        .start
                        .elapsed()
                        .as_millis()
                        .try_into()
                        .unwrap_or(u64::MAX);

                    let reason = e.to_string();
                    repo.queue_job()
                        .mark_as_failed(clock, context.id, &reason)
                        .await?;

                    if context.attempt < MAX_ATTEMPTS {
                        let delay = retry_delay(context.attempt);
                        tracing::error!(
                            error = &e as &dyn std::error::Error,
                            job.id = %context.id,
                            job.queue.name = %context.queue_name,
                            job.attempt = %context.attempt,
                            job.elapsed = format!("{elapsed}ms"),
                            "Job crashed, will retry in {}s",
                            delay.num_seconds()
                        );

                        self.job_processing_time.record(
                            elapsed,
                            &[
                                KeyValue::new("job.queue.name", context.queue_name),
                                KeyValue::new("job.result", "crashed"),
                                KeyValue::new("job.decision", "retry"),
                            ],
                        );

                        repo.queue_job()
                            .retry(&mut *rng, clock, context.id, delay)
                            .await?;
                    } else {
                        tracing::error!(
                            error = &e as &dyn std::error::Error,
                            job.id = %context.id,
                            job.queue.name = %context.queue_name,
                            job.attempt = %context.attempt,
                            job.elapsed = format!("{elapsed}ms"),
                            "Job crashed too many times, abandoning"
                        );

                        self.job_processing_time.record(
                            elapsed,
                            &[
                                KeyValue::new("job.queue.name", context.queue_name),
                                KeyValue::new("job.result", "crashed"),
                                KeyValue::new("job.decision", "abandon"),
                            ],
                        );
                    }
                }
            }

            if blocking {
                self.last_join_result = self.running_jobs.join_next_with_id().await;
            } else {
                self.last_join_result = self.running_jobs.try_join_next_with_id();
            }
        }

        Ok(())
    }
}