Sfoglia il codice sorgente

Closes #21952: Improve robustness of RQ worker check

Jeremy Stretch 2 giorni fa
parent
commit
f391c958f5

+ 3 - 2
netbox/core/views.py

@@ -14,7 +14,7 @@ from django.shortcuts import get_object_or_404, redirect, render
 from django.urls import reverse
 from django.utils.translation import gettext_lazy as _
 from django.views.generic import View
-from django_rq.queues import get_connection, get_queue_by_index, get_redis_connection
+from django_rq.queues import get_queue_by_index, get_redis_connection
 from django_rq.settings import get_queues_list, get_queues_map
 from django_rq.utils import get_statistics
 from rq.exceptions import NoSuchJobError
@@ -55,6 +55,7 @@ from utilities.forms import ConfirmationForm
 from utilities.htmx import htmx_partial
 from utilities.json import ConfigJSONEncoder
 from utilities.query import count_related
+from utilities.rqworker import get_workers_for_queue
 from utilities.views import (
     ContentTypePermissionRequiredMixin,
     GetRelatedModelsMixin,
@@ -707,7 +708,7 @@ class SystemView(UserPassesTestMixin, View):
             'postgresql_version': psql_version,
             'database_name': db_name,
             'database_size': db_size,
-            'rq_worker_count': Worker.count(get_connection('default')),
+            'rq_worker_count': len(get_workers_for_queue('default')),
         }
 
     def _get_object_counts(self):

+ 2 - 3
netbox/extras/api/views.py

@@ -1,6 +1,5 @@
 from django.http import Http404
 from django.shortcuts import get_object_or_404
-from django_rq.queues import get_connection
 from drf_spectacular.utils import extend_schema, extend_schema_view
 from rest_framework import status
 from rest_framework.decorators import action
@@ -11,7 +10,6 @@ from rest_framework.renderers import JSONRenderer
 from rest_framework.response import Response
 from rest_framework.routers import APIRootView
 from rest_framework.viewsets import ModelViewSet
-from rq import Worker
 
 from extras import filtersets
 from extras.jobs import ScriptJob
@@ -24,6 +22,7 @@ from netbox.api.viewsets import BaseViewSet, NetBoxModelViewSet
 from netbox.api.viewsets.mixins import ObjectValidationMixin
 from utilities.exceptions import RQWorkerNotRunningException
 from utilities.request import copy_safe_request
+from utilities.rqworker import any_workers_for_queue
 
 from . import serializers
 from .mixins import ConfigTemplateRenderMixin
@@ -326,7 +325,7 @@ class ScriptViewSet(ModelViewSet):
         )
 
         # Check that at least one RQ worker is running
-        if not Worker.count(get_connection('default')):
+        if not any_workers_for_queue('default'):
             raise RQWorkerNotRunningException()
 
         if input_serializer.is_valid():

+ 3 - 3
netbox/extras/tests/test_views.py

@@ -995,7 +995,7 @@ class ScriptValidationErrorTestCase(TestCase):
     def test_script_validation_error_displays_message(self):
         url = reverse('extras:script', kwargs={'pk': self.script.pk})
 
-        with patch('extras.views.get_workers_for_queue', return_value=['worker']):
+        with patch('extras.views.any_workers_for_queue', return_value=True):
             response = self.client.post(url, {'debug_mode': 'true', '_commit': 'true'})
 
         self.assertEqual(response.status_code, 200)
@@ -1020,7 +1020,7 @@ class ScriptValidationErrorTestCase(TestCase):
 
         with patch.object(Script, 'python_class', new_callable=PropertyMock) as mock_python_class:
             mock_python_class.return_value = FieldsetScript
-            with patch('extras.views.get_workers_for_queue', return_value=['worker']):
+            with patch('extras.views.any_workers_for_queue', return_value=True):
                 response = self.client.post(url, {'required_field': '5', '_commit': 'true'})
 
         self.assertEqual(response.status_code, 200)
@@ -1063,7 +1063,7 @@ class ScriptDefaultValuesTestCase(TestCase):
     def test_default_values_are_used(self):
         url = reverse('extras:script', kwargs={'pk': self.script.pk})
 
-        with patch('extras.views.get_workers_for_queue', return_value=['worker']):
+        with patch('extras.views.any_workers_for_queue', return_value=True):
             with patch('extras.jobs.ScriptJob.enqueue') as mock_enqueue:
                 mock_enqueue.return_value.pk = 1
                 self.client.post(url, {})

+ 2 - 2
netbox/extras/views.py

@@ -38,7 +38,7 @@ from utilities.paginator import EnhancedPaginator, get_paginate_count
 from utilities.query import count_related
 from utilities.querydict import normalize_querydict
 from utilities.request import copy_safe_request
-from utilities.rqworker import get_workers_for_queue
+from utilities.rqworker import any_workers_for_queue
 from utilities.templatetags.builtins.filters import render_markdown
 from utilities.views import ContentTypePermissionRequiredMixin, get_action_url, register_model_view
 from virtualization.models import VirtualMachine
@@ -1729,7 +1729,7 @@ class ScriptView(BaseScriptView):
         form = script_class.as_form(post_data, request.FILES)
 
         # Allow execution only if RQ worker process is running
-        if not get_workers_for_queue('default'):
+        if not any_workers_for_queue('default'):
             messages.error(request, _("Unable to run script: RQ worker process not running."))
         elif form.is_valid():
             ScriptJob = import_string("extras.jobs.ScriptJob")

+ 2 - 3
netbox/netbox/api/views.py

@@ -2,19 +2,18 @@ import platform
 
 from django import __version__ as DJANGO_VERSION
 from django.conf import settings
-from django_rq.queues import get_connection
 from drf_spectacular.types import OpenApiTypes
 from drf_spectacular.utils import extend_schema
 from rest_framework.permissions import IsAuthenticated
 from rest_framework.response import Response
 from rest_framework.reverse import reverse
 from rest_framework.views import APIView
-from rq.worker import Worker
 
 from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
 from netbox.plugins.utils import get_installed_plugins
 from users.api.serializers import UserSerializer
 from utilities.apps import get_installed_apps
+from utilities.rqworker import get_workers_for_queue
 
 
 class APIRootView(APIView):
@@ -62,7 +61,7 @@ class StatusView(APIView):
             'netbox-full-version': settings.RELEASE.full_version,
             'plugins': get_installed_plugins(),
             'python-version': platform.python_version(),
-            'rq-workers-running': Worker.count(get_connection('default')),
+            'rq-workers-running': len(get_workers_for_queue('default')),
         })
 
 

+ 1 - 0
netbox/netbox/settings.py

@@ -176,6 +176,7 @@ REMOTE_AUTH_USER_LAST_NAME = getattr(configuration, 'REMOTE_AUTH_USER_LAST_NAME'
 # Required by extras/migrations/0109_script_models.py
 REPORTS_ROOT = getattr(configuration, 'REPORTS_ROOT', os.path.join(BASE_DIR, 'reports')).rstrip('/')
 RQ = getattr(configuration, 'RQ', {})
+RQ.setdefault('WORKER_CLASS', 'utilities.rqworker.NetBoxRQWorker')
 RQ_DEFAULT_TIMEOUT = getattr(configuration, 'RQ_DEFAULT_TIMEOUT', 300)
 RQ_RETRY_INTERVAL = getattr(configuration, 'RQ_RETRY_INTERVAL', 60)
 RQ_RETRY_MAX = getattr(configuration, 'RQ_RETRY_MAX', 0)

+ 85 - 2
netbox/utilities/rqworker.py

@@ -1,15 +1,56 @@
+import logging
+from datetime import timedelta
+
+from django.utils.timezone import now
 from django_rq.queues import get_connection
 from rq import Retry, Worker
+from rq.worker_registration import REDIS_WORKER_KEYS
+from rq.worker_registration import register as register_worker
 
 from netbox.config import get_config
 from netbox.constants import RQ_QUEUE_DEFAULT
 
 __all__ = (
+    'NetBoxRQWorker',
+    'any_workers_for_queue',
     'get_queue_for_model',
     'get_rq_retry',
     'get_workers_for_queue',
 )
 
+logger = logging.getLogger('netbox.rqworker')
+
+# Matches rq.defaults.DEFAULT_WORKER_TTL; used as a fallback when a worker's
+# own worker_ttl is unavailable.
+DEFAULT_WORKER_TTL = 420
+
+
+class NetBoxRQWorker(Worker):
+    """
+    RQ worker subclass which self-heals its registration. If the worker's
+    registration is missing from Redis (e.g. because the tasks Redis database
+    was lost and rebuilt while the worker was running), the next heartbeat
+    will re-register the worker so that Worker.all() / Worker.count() can find
+    it again.
+    """
+
+    def heartbeat(self, timeout=None, pipeline=None):
+        try:
+            if not self.connection.sismember(REDIS_WORKER_KEYS, self.key):
+                logger.warning(f"Worker {self.name} not found in registry; re-registering.")
+                # If the worker hash still exists (partial Redis data loss),
+                # register_birth() would raise because rq treats an existing,
+                # non-dead hash as an active worker. Re-add to the registry
+                # sets directly in that case; the heartbeat below will refresh
+                # the hash TTL.
+                if self.connection.exists(self.key) and not self.connection.hexists(self.key, 'death'):
+                    register_worker(self, self.connection)
+                else:
+                    self.register_birth()
+        except Exception:
+            logger.exception("Failed to verify worker registration.")
+        super().heartbeat(timeout=timeout, pipeline=pipeline)
+
 
 def get_queue_for_model(model):
     """
@@ -18,11 +59,53 @@ def get_queue_for_model(model):
     return get_config().QUEUE_MAPPINGS.get(model, RQ_QUEUE_DEFAULT)
 
 
+def _is_live_worker(worker, queue_name, threshold):
+    """
+    Return True if the given Worker is currently servicing queue_name and its
+    last heartbeat is within (worker_ttl + 60s) of threshold.
+    """
+    if queue_name not in worker.queue_names():
+        return False
+    last = worker.last_heartbeat
+    if last is None:
+        return False
+    ttl = (worker.worker_ttl or DEFAULT_WORKER_TTL) + 60
+    return (threshold - last) <= timedelta(seconds=ttl)
+
+
 def get_workers_for_queue(queue_name):
     """
-    Returns True if a worker process is currently servicing the specified queue.
+    Return the set of worker names currently servicing the given queue.
+
+    Workers are filtered by last_heartbeat freshness: any worker whose most
+    recent heartbeat exceeds (worker_ttl + 60s) is excluded. This prevents
+    stale registrations (workers that died without deregistering) from being
+    reported as alive.
+    """
+    connection = get_connection(queue_name)
+    threshold = now()
+    return {
+        worker.name for worker in Worker.all(connection=connection)
+        if _is_live_worker(worker, queue_name, threshold)
+    }
+
+
+def any_workers_for_queue(queue_name):
+    """
+    Return True if at least one live worker is currently servicing the given
+    queue. Cheaper than get_workers_for_queue() when only a liveness check is
+    needed: workers are fetched one at a time and iteration stops at the first
+    live match.
     """
-    return Worker.count(get_connection(queue_name))
+    connection = get_connection(queue_name)
+    threshold = now()
+    for key in Worker.all_keys(connection=connection):
+        worker = Worker.find_by_key(key, connection=connection)
+        if worker is None:
+            continue
+        if _is_live_worker(worker, queue_name, threshold):
+            return True
+    return False
 
 
 def get_rq_retry():

+ 237 - 0
netbox/utilities/tests/test_rqworker.py

@@ -0,0 +1,237 @@
+from datetime import timedelta
+from unittest.mock import MagicMock, patch
+
+from django.test import TestCase
+from django.utils.timezone import now
+
+from utilities.rqworker import (
+    DEFAULT_WORKER_TTL,
+    NetBoxRQWorker,
+    any_workers_for_queue,
+    get_workers_for_queue,
+)
+
+
+def _make_worker(name='worker-1', queues=('default',), last_heartbeat_age_seconds=10, worker_ttl=DEFAULT_WORKER_TTL):
+    """
+    Build a MagicMock that mimics the rq.Worker attributes consumed by
+    get_workers_for_queue().
+    """
+    worker = MagicMock()
+    worker.name = name
+    worker.queue_names.return_value = list(queues)
+    worker.last_heartbeat = (
+        now() - timedelta(seconds=last_heartbeat_age_seconds)
+        if last_heartbeat_age_seconds is not None else None
+    )
+    worker.worker_ttl = worker_ttl
+    return worker
+
+
+class NetBoxRQWorkerHeartbeatTestCase(TestCase):
+    """
+    The overridden heartbeat() must call register_birth() iff the worker is
+    missing from the rq:workers registry set, and must always invoke
+    super().heartbeat().
+    """
+
+    def _make_subject(self, is_member, hash_exists=False, marked_dead=False):
+        worker = NetBoxRQWorker.__new__(NetBoxRQWorker)
+        worker.name = 'test-worker'
+        worker.connection = MagicMock()
+        worker.connection.sismember.return_value = is_member
+        worker.connection.exists.return_value = hash_exists
+        worker.connection.hexists.return_value = marked_dead
+        worker.register_birth = MagicMock()
+        worker.log = MagicMock()
+        return worker
+
+    def test_heartbeat_skips_register_when_present(self):
+        worker = self._make_subject(is_member=True)
+        with patch('rq.Worker.heartbeat') as super_heartbeat:
+            NetBoxRQWorker.heartbeat(worker)
+        worker.register_birth.assert_not_called()
+        super_heartbeat.assert_called_once()
+
+    def test_heartbeat_calls_register_birth_when_hash_missing(self):
+        # Full data loss: set membership and hash both gone.
+        worker = self._make_subject(is_member=False, hash_exists=False)
+        with patch('rq.Worker.heartbeat') as super_heartbeat, \
+                patch('utilities.rqworker.register_worker') as register_set:
+            NetBoxRQWorker.heartbeat(worker)
+        worker.register_birth.assert_called_once()
+        register_set.assert_not_called()
+        super_heartbeat.assert_called_once()
+
+    def test_heartbeat_readds_to_set_when_hash_survives(self):
+        # Partial data loss: hash present (and not dead), set membership gone.
+        # register_birth() would raise here; we must re-add to the set instead.
+        worker = self._make_subject(is_member=False, hash_exists=True, marked_dead=False)
+        with patch('rq.Worker.heartbeat') as super_heartbeat, \
+                patch('utilities.rqworker.register_worker') as register_set:
+            NetBoxRQWorker.heartbeat(worker)
+        worker.register_birth.assert_not_called()
+        register_set.assert_called_once_with(worker, worker.connection)
+        # Liveness is gated on the 'death' hash field specifically; pin the
+        # field name so a typo can't silently fall through to register_birth().
+        worker.connection.hexists.assert_called_with(worker.key, 'death')
+        super_heartbeat.assert_called_once()
+
+    def test_heartbeat_calls_register_birth_when_hash_marked_dead(self):
+        # Hash exists but is marked dead -- treat as full recreate.
+        worker = self._make_subject(is_member=False, hash_exists=True, marked_dead=True)
+        with patch('rq.Worker.heartbeat') as super_heartbeat, \
+                patch('utilities.rqworker.register_worker') as register_set:
+            NetBoxRQWorker.heartbeat(worker)
+        worker.register_birth.assert_called_once()
+        register_set.assert_not_called()
+        super_heartbeat.assert_called_once()
+
+    def test_heartbeat_tolerates_redis_exception(self):
+        worker = self._make_subject(is_member=False)
+        worker.connection.sismember.side_effect = RuntimeError('redis down')
+        with patch('rq.Worker.heartbeat') as super_heartbeat:
+            # Must not raise
+            NetBoxRQWorker.heartbeat(worker)
+        worker.register_birth.assert_not_called()
+        super_heartbeat.assert_called_once()
+
+
+class GetWorkersForQueueTestCase(TestCase):
+    """
+    get_workers_for_queue() must:
+      * include workers servicing the queue whose last_heartbeat is fresh
+      * exclude workers whose last_heartbeat is stale
+      * exclude workers whose last_heartbeat is None
+      * exclude workers not listening on the requested queue
+      * return an empty set when no workers exist
+    """
+
+    def _patch_worker_all(self, workers):
+        return patch('utilities.rqworker.Worker.all', return_value=workers)
+
+    def _patch_get_connection(self):
+        return patch('utilities.rqworker.get_connection', return_value=MagicMock())
+
+    def test_returns_fresh_worker_for_queue(self):
+        workers = [_make_worker(name='alive', last_heartbeat_age_seconds=10)]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, {'alive'})
+
+    def test_excludes_stale_worker(self):
+        # Heartbeat older than worker_ttl + 60
+        stale_age = DEFAULT_WORKER_TTL + 120
+        workers = [_make_worker(name='stale', last_heartbeat_age_seconds=stale_age)]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, set())
+
+    def test_excludes_worker_with_no_heartbeat(self):
+        workers = [_make_worker(name='cold', last_heartbeat_age_seconds=None)]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, set())
+
+    def test_excludes_worker_for_other_queue(self):
+        workers = [_make_worker(name='other', queues=('high',))]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, set())
+
+    def test_returns_empty_when_no_workers(self):
+        with self._patch_get_connection(), self._patch_worker_all([]):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, set())
+
+    def test_includes_worker_listening_on_multiple_queues(self):
+        workers = [_make_worker(name='multi', queues=('high', 'default', 'low'))]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, {'multi'})
+
+    def test_includes_worker_with_null_ttl(self):
+        # When a worker reports worker_ttl=None, the freshness window must
+        # fall back to DEFAULT_WORKER_TTL rather than raising on `None + 60`.
+        workers = [_make_worker(name='nullttl', last_heartbeat_age_seconds=10, worker_ttl=None)]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, {'nullttl'})
+
+    def test_mixed_fresh_and_stale_workers(self):
+        workers = [
+            _make_worker(name='alive', last_heartbeat_age_seconds=10),
+            _make_worker(name='stale', last_heartbeat_age_seconds=DEFAULT_WORKER_TTL + 120),
+            _make_worker(name='other-queue', queues=('high',)),
+        ]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, {'alive'})
+
+
+class AnyWorkersForQueueTestCase(TestCase):
+    """
+    any_workers_for_queue() must apply the same liveness filter as
+    get_workers_for_queue(), but short-circuit on the first live match.
+    """
+
+    def _patch_keys_and_lookup(self, workers):
+        keys = [f'rq:worker:{w.name}' for w in workers]
+        by_key = dict(zip(keys, workers))
+        return (
+            patch('utilities.rqworker.Worker.all_keys', return_value=keys),
+            patch('utilities.rqworker.Worker.find_by_key', side_effect=lambda key, connection=None: by_key.get(key)),
+        )
+
+    def _patch_get_connection(self):
+        return patch('utilities.rqworker.get_connection', return_value=MagicMock())
+
+    def test_returns_true_when_fresh_worker_present(self):
+        workers = [_make_worker(name='alive', last_heartbeat_age_seconds=10)]
+        keys_patch, find_patch = self._patch_keys_and_lookup(workers)
+        with self._patch_get_connection(), keys_patch, find_patch:
+            self.assertTrue(any_workers_for_queue('default'))
+
+    def test_returns_false_when_only_stale_workers(self):
+        workers = [_make_worker(name='stale', last_heartbeat_age_seconds=DEFAULT_WORKER_TTL + 120)]
+        keys_patch, find_patch = self._patch_keys_and_lookup(workers)
+        with self._patch_get_connection(), keys_patch, find_patch:
+            self.assertFalse(any_workers_for_queue('default'))
+
+    def test_returns_false_when_no_workers(self):
+        keys_patch, find_patch = self._patch_keys_and_lookup([])
+        with self._patch_get_connection(), keys_patch, find_patch:
+            self.assertFalse(any_workers_for_queue('default'))
+
+    def test_returns_false_when_only_other_queue(self):
+        workers = [_make_worker(name='other', queues=('high',))]
+        keys_patch, find_patch = self._patch_keys_and_lookup(workers)
+        with self._patch_get_connection(), keys_patch, find_patch:
+            self.assertFalse(any_workers_for_queue('default'))
+
+    def test_short_circuits_on_first_live_worker(self):
+        # The first key resolves to a live worker; subsequent keys must not
+        # be fetched.
+        workers = [
+            _make_worker(name='alive', last_heartbeat_age_seconds=10),
+            _make_worker(name='other', last_heartbeat_age_seconds=10),
+        ]
+        keys = [f'rq:worker:{w.name}' for w in workers]
+        by_key = dict(zip(keys, workers))
+        find = MagicMock(side_effect=lambda key, connection=None: by_key.get(key))
+        with self._patch_get_connection(), \
+                patch('utilities.rqworker.Worker.all_keys', return_value=keys), \
+                patch('utilities.rqworker.Worker.find_by_key', find):
+            self.assertTrue(any_workers_for_queue('default'))
+        self.assertEqual(find.call_count, 1)
+
+    def test_skips_missing_workers(self):
+        # find_by_key returning None (stale registry entry pointing to a
+        # vanished hash) must not raise; iteration continues to the next key.
+        live = _make_worker(name='alive', last_heartbeat_age_seconds=10)
+        keys = ['rq:worker:ghost', 'rq:worker:alive']
+        find = MagicMock(side_effect=[None, live])
+        with self._patch_get_connection(), \
+                patch('utilities.rqworker.Worker.all_keys', return_value=keys), \
+                patch('utilities.rqworker.Worker.find_by_key', find):
+            self.assertTrue(any_workers_for_queue('default'))