Explorar o código

Closes #21952: Improve robustness of RQ worker check (#22234)

Jeremy Stretch hai 16 horas
pai
achega
a8f609bf21

+ 3 - 2
netbox/core/views.py

@@ -14,7 +14,7 @@ from django.shortcuts import get_object_or_404, redirect, render
 from django.urls import reverse
 from django.utils.translation import gettext_lazy as _
 from django.views.generic import View
-from django_rq.queues import get_connection, get_queue_by_index, get_redis_connection
+from django_rq.queues import get_queue_by_index, get_redis_connection
 from django_rq.settings import get_queues_list, get_queues_map
 from django_rq.utils import get_statistics
 from rq.exceptions import NoSuchJobError
@@ -55,6 +55,7 @@ from utilities.forms import ConfirmationForm
 from utilities.htmx import htmx_partial
 from utilities.json import ConfigJSONEncoder
 from utilities.query import count_related
+from utilities.rqworker import get_all_workers
 from utilities.views import (
     ContentTypePermissionRequiredMixin,
     GetRelatedModelsMixin,
@@ -707,7 +708,7 @@ class SystemView(UserPassesTestMixin, View):
             'postgresql_version': psql_version,
             'database_name': db_name,
             'database_size': db_size,
-            'rq_worker_count': Worker.count(get_connection('default')),
+            'rq_worker_count': len(get_all_workers()),
         }
 
     def _get_object_counts(self):

+ 2 - 3
netbox/extras/api/views.py

@@ -1,6 +1,5 @@
 from django.http import Http404
 from django.shortcuts import get_object_or_404
-from django_rq.queues import get_connection
 from drf_spectacular.utils import extend_schema, extend_schema_view
 from rest_framework import status
 from rest_framework.decorators import action
@@ -11,7 +10,6 @@ from rest_framework.renderers import JSONRenderer
 from rest_framework.response import Response
 from rest_framework.routers import APIRootView
 from rest_framework.viewsets import ModelViewSet
-from rq import Worker
 
 from extras import filtersets
 from extras.jobs import ScriptJob
@@ -24,6 +22,7 @@ from netbox.api.viewsets import BaseViewSet, NetBoxModelViewSet
 from netbox.api.viewsets.mixins import ObjectValidationMixin
 from utilities.exceptions import RQWorkerNotRunningException
 from utilities.request import copy_safe_request
+from utilities.rqworker import any_workers_for_queue
 
 from . import serializers
 from .mixins import ConfigTemplateRenderMixin
@@ -326,7 +325,7 @@ class ScriptViewSet(ModelViewSet):
         )
 
         # Check that at least one RQ worker is running
-        if not Worker.count(get_connection('default')):
+        if not any_workers_for_queue('default'):
             raise RQWorkerNotRunningException()
 
         if input_serializer.is_valid():

+ 3 - 3
netbox/extras/tests/test_views.py

@@ -1088,7 +1088,7 @@ class ScriptValidationErrorTestCase(TestCase):
     def test_script_validation_error_displays_message(self):
         url = reverse('extras:script', kwargs={'pk': self.script.pk})
 
-        with patch('extras.views.get_workers_for_queue', return_value=['worker']):
+        with patch('extras.views.any_workers_for_queue', return_value=True):
             response = self.client.post(url, {'debug_mode': 'true', '_commit': 'true'})
 
         self.assertEqual(response.status_code, 200)
@@ -1113,7 +1113,7 @@ class ScriptValidationErrorTestCase(TestCase):
 
         with patch.object(Script, 'python_class', new_callable=PropertyMock) as mock_python_class:
             mock_python_class.return_value = FieldsetScript
-            with patch('extras.views.get_workers_for_queue', return_value=['worker']):
+            with patch('extras.views.any_workers_for_queue', return_value=True):
                 response = self.client.post(url, {'required_field': '5', '_commit': 'true'})
 
         self.assertEqual(response.status_code, 200)
@@ -1156,7 +1156,7 @@ class ScriptDefaultValuesTestCase(TestCase):
     def test_default_values_are_used(self):
         url = reverse('extras:script', kwargs={'pk': self.script.pk})
 
-        with patch('extras.views.get_workers_for_queue', return_value=['worker']):
+        with patch('extras.views.any_workers_for_queue', return_value=True):
             with patch('extras.jobs.ScriptJob.enqueue') as mock_enqueue:
                 mock_enqueue.return_value.pk = 1
                 self.client.post(url, {})

+ 2 - 2
netbox/extras/views.py

@@ -38,7 +38,7 @@ from utilities.paginator import EnhancedPaginator, get_paginate_count
 from utilities.query import count_related
 from utilities.querydict import normalize_querydict
 from utilities.request import copy_safe_request
-from utilities.rqworker import get_workers_for_queue
+from utilities.rqworker import any_workers_for_queue
 from utilities.templatetags.builtins.filters import render_markdown
 from utilities.views import ContentTypePermissionRequiredMixin, get_action_url, register_model_view
 from virtualization.models import VirtualMachine
@@ -1729,7 +1729,7 @@ class ScriptView(BaseScriptView):
         form = script_class.as_form(post_data, request.FILES)
 
         # Allow execution only if RQ worker process is running
-        if not get_workers_for_queue('default'):
+        if not any_workers_for_queue('default'):
             messages.error(request, _("Unable to run script: RQ worker process not running."))
         elif form.is_valid():
             ScriptJob = import_string("extras.jobs.ScriptJob")

+ 2 - 3
netbox/netbox/api/views.py

@@ -2,19 +2,18 @@ import platform
 
 from django import __version__ as DJANGO_VERSION
 from django.conf import settings
-from django_rq.queues import get_connection
 from drf_spectacular.types import OpenApiTypes
 from drf_spectacular.utils import extend_schema
 from rest_framework.permissions import IsAuthenticated
 from rest_framework.response import Response
 from rest_framework.reverse import reverse
 from rest_framework.views import APIView
-from rq.worker import Worker
 
 from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
 from netbox.plugins.utils import get_installed_plugins
 from users.api.serializers import UserSerializer
 from utilities.apps import get_installed_apps
+from utilities.rqworker import get_all_workers
 
 
 class APIRootView(APIView):
@@ -62,7 +61,7 @@ class StatusView(APIView):
             'netbox-full-version': settings.RELEASE.full_version,
             'plugins': get_installed_plugins(),
             'python-version': platform.python_version(),
-            'rq-workers-running': Worker.count(get_connection('default')),
+            'rq-workers-running': len(get_all_workers()),
         })
 
 

+ 7 - 0
netbox/netbox/settings.py

@@ -176,6 +176,13 @@ REMOTE_AUTH_USER_LAST_NAME = getattr(configuration, 'REMOTE_AUTH_USER_LAST_NAME'
 # Required by extras/migrations/0109_script_models.py
 REPORTS_ROOT = getattr(configuration, 'REPORTS_ROOT', os.path.join(BASE_DIR, 'reports')).rstrip('/')
 RQ = getattr(configuration, 'RQ', {})
+if 'WORKER_CLASS' in RQ and RQ['WORKER_CLASS'] != 'utilities.rqworker.NetBoxRQWorker':
+    warnings.warn(
+        f"RQ['WORKER_CLASS'] is set to {RQ['WORKER_CLASS']!r}; NetBoxRQWorker's self-healing heartbeat "
+        f"logic will not be applied. Workers may not automatically recover from a Redis outage."
+    )
+else:
+    RQ.setdefault('WORKER_CLASS', 'utilities.rqworker.NetBoxRQWorker')
 RQ_DEFAULT_TIMEOUT = getattr(configuration, 'RQ_DEFAULT_TIMEOUT', 300)
 RQ_RETRY_INTERVAL = getattr(configuration, 'RQ_RETRY_INTERVAL', 60)
 RQ_RETRY_MAX = getattr(configuration, 'RQ_RETRY_MAX', 0)

+ 0 - 1
netbox/templates/core/system.html

@@ -94,7 +94,6 @@
               <th scope="row">{% trans "RQ workers" %}</th>
               <td>
                 <a href="{% url 'core:background_queue_list' %}">{{ stats.rq_worker_count }}</a>
-                ({% trans "default queue" %})
               </td>
             </tr>
             <tr>

+ 90 - 2
netbox/utilities/rqworker.py

@@ -1,15 +1,51 @@
+import logging
+
 from django_rq.queues import get_connection
 from rq import Retry, Worker
+from rq.worker_registration import REDIS_WORKER_KEYS
+from rq.worker_registration import register as register_worker
 
 from netbox.config import get_config
 from netbox.constants import RQ_QUEUE_DEFAULT
 
 __all__ = (
+    'NetBoxRQWorker',
+    'any_workers_for_queue',
+    'get_all_workers',
     'get_queue_for_model',
     'get_rq_retry',
     'get_workers_for_queue',
 )
 
+logger = logging.getLogger('netbox.rqworker')
+
+
+class NetBoxRQWorker(Worker):
+    """
+    RQ worker subclass which self-heals its registration. If the worker's
+    registration is missing from Redis (e.g. because the tasks Redis database
+    was lost and rebuilt while the worker was running), the next heartbeat
+    will re-register the worker so that Worker.all() / Worker.find_by_key()
+    can locate it again.
+    """
+
+    def heartbeat(self, *args, **kwargs):
+        try:
+            if not self.connection.sismember(REDIS_WORKER_KEYS, self.key):
+                logger.warning(f"Worker {self.name} not found in registry; re-registering.")
+                # If the worker hash still exists (partial Redis data loss),
+                # register_birth() would raise because rq treats an existing,
+                # non-dead hash as an active worker. Re-add to the registry
+                # sets directly in that case; the heartbeat below will refresh
+                # the hash TTL.
+                if self.connection.exists(self.key) and not self.connection.hexists(self.key, 'death'):
+                    register_worker(self, self.connection)
+                else:
+                    self.register_birth()
+        except Exception:
+            logger.exception("Failed to verify worker registration.")
+        super().heartbeat(*args, **kwargs)
+
 
 def get_queue_for_model(model):
     """
@@ -18,11 +54,63 @@ def get_queue_for_model(model):
     return get_config().QUEUE_MAPPINGS.get(model, RQ_QUEUE_DEFAULT)
 
 
+def _is_live_worker(worker, queue_name):
+    """
+    Return True if the given Worker is currently servicing queue_name.
+
+    Liveness itself is enforced by RQ: Worker.all() / Worker.find_by_key()
+    only return workers whose Redis hash still exists, and RQ resets that
+    hash's expiry to (worker_ttl + 60s) on every heartbeat. So any worker
+    returned by RQ has heartbeat'd within its configured TTL -- we only need
+    to confirm it's listening on the requested queue. (Reconstructing
+    worker_ttl ourselves would be unsafe: RQ does not persist worker_ttl in
+    the hash, so a worker started with a non-default --worker-ttl is
+    reconstructed with DEFAULT_WORKER_TTL regardless of its real TTL.)
+    """
+    return queue_name in worker.queue_names()
+
+
 def get_workers_for_queue(queue_name):
     """
-    Returns True if a worker process is currently servicing the specified queue.
+    Return the number of live workers currently servicing the given queue.
+    """
+    connection = get_connection(queue_name)
+    return sum(
+        1 for worker in Worker.all(connection=connection)
+        if _is_live_worker(worker, queue_name)
+    )
+
+
+def get_all_workers():
+    """
+    Return the set of worker names currently registered on the tasks Redis
+    connection, regardless of which queue(s) each worker is servicing. Stale
+    registrations (workers whose Redis hash has expired) are filtered out by
+    RQ via Worker.all() -- see _is_live_worker() for details.
+
+    Used for system-wide worker counts (dashboard, status API), where the
+    intent is "are any RQ workers running" rather than "are workers handling
+    a specific queue."
+    """
+    connection = get_connection(RQ_QUEUE_DEFAULT)
+    return {worker.name for worker in Worker.all(connection=connection)}
+
+
+def any_workers_for_queue(queue_name):
+    """
+    Return True if at least one live worker is currently servicing the given
+    queue. Cheaper than get_workers_for_queue() when only a liveness check is
+    needed: workers are fetched one at a time and iteration stops at the first
+    live match.
     """
-    return Worker.count(get_connection(queue_name))
+    connection = get_connection(queue_name)
+    for key in Worker.all_keys(connection=connection):
+        worker = Worker.find_by_key(key, connection=connection)
+        if worker is None:
+            continue
+        if _is_live_worker(worker, queue_name):
+            return True
+    return False
 
 
 def get_rq_retry():

+ 252 - 0
netbox/utilities/tests/test_rqworker.py

@@ -0,0 +1,252 @@
+from unittest.mock import MagicMock, patch
+
+from django.test import TestCase
+
+from utilities.rqworker import (
+    NetBoxRQWorker,
+    any_workers_for_queue,
+    get_all_workers,
+    get_workers_for_queue,
+)
+
+
+def _make_worker(name='worker-1', queues=('default',)):
+    """
+    Build a MagicMock that mimics the rq.Worker attributes consumed by
+    get_workers_for_queue() / any_workers_for_queue().
+
+    Heartbeat freshness is intentionally not modeled here: liveness is
+    enforced by RQ itself (Worker.all() / find_by_key() only return workers
+    whose Redis hash has not expired), so any worker reaching our code is
+    already known-fresh.
+    """
+    worker = MagicMock()
+    worker.name = name
+    worker.queue_names.return_value = list(queues)
+    return worker
+
+
+class NetBoxRQWorkerHeartbeatTestCase(TestCase):
+    """
+    The overridden heartbeat() must call register_birth() iff the worker is
+    missing from the rq:workers registry set, and must always invoke
+    super().heartbeat().
+    """
+
+    def _make_subject(self, is_member, hash_exists=False, marked_dead=False):
+        worker = NetBoxRQWorker.__new__(NetBoxRQWorker)
+        worker.name = 'test-worker'
+        worker.connection = MagicMock()
+        worker.connection.sismember.return_value = is_member
+        worker.connection.exists.return_value = hash_exists
+        worker.connection.hexists.return_value = marked_dead
+        worker.register_birth = MagicMock()
+        worker.log = MagicMock()
+        return worker
+
+    def test_heartbeat_skips_register_when_present(self):
+        worker = self._make_subject(is_member=True)
+        with patch('rq.Worker.heartbeat') as super_heartbeat:
+            NetBoxRQWorker.heartbeat(worker)
+        worker.register_birth.assert_not_called()
+        super_heartbeat.assert_called_once()
+
+    def test_heartbeat_calls_register_birth_when_hash_missing(self):
+        # Full data loss: set membership and hash both gone.
+        worker = self._make_subject(is_member=False, hash_exists=False)
+        with patch('rq.Worker.heartbeat') as super_heartbeat, \
+                patch('utilities.rqworker.register_worker') as register_set:
+            NetBoxRQWorker.heartbeat(worker)
+        worker.register_birth.assert_called_once()
+        register_set.assert_not_called()
+        super_heartbeat.assert_called_once()
+
+    def test_heartbeat_readds_to_set_when_hash_survives(self):
+        # Partial data loss: hash present (and not dead), set membership gone.
+        # register_birth() would raise here; we must re-add to the set instead.
+        worker = self._make_subject(is_member=False, hash_exists=True, marked_dead=False)
+        with patch('rq.Worker.heartbeat') as super_heartbeat, \
+                patch('utilities.rqworker.register_worker') as register_set:
+            NetBoxRQWorker.heartbeat(worker)
+        worker.register_birth.assert_not_called()
+        register_set.assert_called_once_with(worker, worker.connection)
+        # Liveness is gated on the 'death' hash field specifically; pin the
+        # field name so a typo can't silently fall through to register_birth().
+        worker.connection.hexists.assert_called_with(worker.key, 'death')
+        super_heartbeat.assert_called_once()
+
+    def test_heartbeat_calls_register_birth_when_hash_marked_dead(self):
+        # Hash exists but is marked dead -- treat as full recreate.
+        worker = self._make_subject(is_member=False, hash_exists=True, marked_dead=True)
+        with patch('rq.Worker.heartbeat') as super_heartbeat, \
+                patch('utilities.rqworker.register_worker') as register_set:
+            NetBoxRQWorker.heartbeat(worker)
+        worker.register_birth.assert_called_once()
+        register_set.assert_not_called()
+        super_heartbeat.assert_called_once()
+
+    def test_registration_check_exception_still_delegates_to_super_heartbeat(self):
+        # A Redis failure in the registration-check branch (sismember) must
+        # not abort the heartbeat; the parent heartbeat must still be invoked
+        # (whether it then succeeds against a degraded Redis is rq's concern,
+        # not ours -- we patch it out here to isolate our wrapper's behavior).
+        worker = self._make_subject(is_member=False)
+        worker.connection.sismember.side_effect = RuntimeError('redis down')
+        with patch('rq.Worker.heartbeat') as super_heartbeat:
+            # Must not raise
+            NetBoxRQWorker.heartbeat(worker)
+        worker.register_birth.assert_not_called()
+        super_heartbeat.assert_called_once()
+
+
+class GetWorkersForQueueTestCase(TestCase):
+    """
+    get_workers_for_queue() must:
+      * count workers servicing the queue (RQ filters by liveness for us)
+      * exclude workers not listening on the requested queue
+      * return 0 when no workers exist
+    """
+
+    def _patch_worker_all(self, workers):
+        return patch('utilities.rqworker.Worker.all', return_value=workers)
+
+    def _patch_get_connection(self):
+        return patch('utilities.rqworker.get_connection', return_value=MagicMock())
+
+    def test_returns_worker_for_queue(self):
+        workers = [_make_worker(name='alive')]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, 1)
+
+    def test_excludes_worker_for_other_queue(self):
+        workers = [_make_worker(name='other', queues=('high',))]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, 0)
+
+    def test_returns_zero_when_no_workers(self):
+        with self._patch_get_connection(), self._patch_worker_all([]):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, 0)
+
+    def test_includes_worker_listening_on_multiple_queues(self):
+        workers = [_make_worker(name='multi', queues=('high', 'default', 'low'))]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, 1)
+
+    def test_includes_worker_with_custom_ttl(self):
+        # A worker started with --worker-ttl != default is reconstructed by RQ with the default TTL (RQ does not persist
+        # worker_ttl in the worker hash). The fact that Worker.all() returned the worker at all is RQ's confirmation
+        # that the hash hasn't expired -- so we must include it regardless of how stale its heartbeat would look
+        # measured against the default TTL.
+        worker = _make_worker(name='long-ttl')
+        worker.worker_ttl = 420  # rq's DEFAULT_WORKER_TTL, what find_by_key would produce
+        worker.last_heartbeat = None  # heartbeat-derived freshness must not gate inclusion
+        with self._patch_get_connection(), self._patch_worker_all([worker]):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, 1)
+
+    def test_filters_to_queue_in_mixed_set(self):
+        workers = [
+            _make_worker(name='default-worker'),
+            _make_worker(name='high-worker', queues=('high',)),
+        ]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_workers_for_queue('default')
+        self.assertEqual(result, 1)
+
+
+class AnyWorkersForQueueTestCase(TestCase):
+    """
+    any_workers_for_queue() must apply the same queue filter as
+    get_workers_for_queue(), but short-circuit on the first live match.
+    """
+
+    def _patch_keys_and_lookup(self, workers):
+        keys = [f'rq:worker:{w.name}' for w in workers]
+        by_key = dict(zip(keys, workers))
+        return (
+            patch('utilities.rqworker.Worker.all_keys', return_value=keys),
+            patch('utilities.rqworker.Worker.find_by_key', side_effect=lambda key, connection=None: by_key.get(key)),
+        )
+
+    def _patch_get_connection(self):
+        return patch('utilities.rqworker.get_connection', return_value=MagicMock())
+
+    def test_returns_true_when_worker_present(self):
+        workers = [_make_worker(name='alive')]
+        keys_patch, find_patch = self._patch_keys_and_lookup(workers)
+        with self._patch_get_connection(), keys_patch, find_patch:
+            self.assertTrue(any_workers_for_queue('default'))
+
+    def test_returns_false_when_no_workers(self):
+        keys_patch, find_patch = self._patch_keys_and_lookup([])
+        with self._patch_get_connection(), keys_patch, find_patch:
+            self.assertFalse(any_workers_for_queue('default'))
+
+    def test_returns_false_when_only_other_queue(self):
+        workers = [_make_worker(name='other', queues=('high',))]
+        keys_patch, find_patch = self._patch_keys_and_lookup(workers)
+        with self._patch_get_connection(), keys_patch, find_patch:
+            self.assertFalse(any_workers_for_queue('default'))
+
+    def test_short_circuits_on_first_live_worker(self):
+        # The first key resolves to a live worker; subsequent keys must not
+        # be fetched.
+        workers = [
+            _make_worker(name='alive'),
+            _make_worker(name='other'),
+        ]
+        keys = [f'rq:worker:{w.name}' for w in workers]
+        by_key = dict(zip(keys, workers))
+        find = MagicMock(side_effect=lambda key, connection=None: by_key.get(key))
+        with self._patch_get_connection(), \
+                patch('utilities.rqworker.Worker.all_keys', return_value=keys), \
+                patch('utilities.rqworker.Worker.find_by_key', find):
+            self.assertTrue(any_workers_for_queue('default'))
+        self.assertEqual(find.call_count, 1)
+
+    def test_skips_missing_workers(self):
+        # find_by_key returning None (stale registry entry pointing to a
+        # vanished hash) must not raise; iteration continues to the next key.
+        live = _make_worker(name='alive')
+        keys = ['rq:worker:ghost', 'rq:worker:alive']
+        find = MagicMock(side_effect=[None, live])
+        with self._patch_get_connection(), \
+                patch('utilities.rqworker.Worker.all_keys', return_value=keys), \
+                patch('utilities.rqworker.Worker.find_by_key', find):
+            self.assertTrue(any_workers_for_queue('default'))
+
+
+class GetAllWorkersTestCase(TestCase):
+    """
+    get_all_workers() must return all live workers regardless of which queue
+    they service. This preserves the queue-agnostic semantics of the
+    dashboard / status API counters that previously used
+    Worker.count(get_connection('default')).
+    """
+
+    def _patch_worker_all(self, workers):
+        return patch('utilities.rqworker.Worker.all', return_value=workers)
+
+    def _patch_get_connection(self):
+        return patch('utilities.rqworker.get_connection', return_value=MagicMock())
+
+    def test_returns_workers_across_all_queues(self):
+        # Workers on non-default queues must still be counted -- the prior
+        # contract (Worker.count(connection)) was queue-agnostic.
+        workers = [
+            _make_worker(name='default-worker'),
+            _make_worker(name='high-worker', queues=('high',)),
+            _make_worker(name='low-worker', queues=('low',)),
+        ]
+        with self._patch_get_connection(), self._patch_worker_all(workers):
+            result = get_all_workers()
+        self.assertEqual(result, {'default-worker', 'high-worker', 'low-worker'})
+
+    def test_returns_empty_when_no_workers(self):
+        with self._patch_get_connection(), self._patch_worker_all([]):
+            result = get_all_workers()
+        self.assertEqual(result, set())