__init__.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. from collections import namedtuple
  2. from decimal import Decimal
  3. from django.core.exceptions import FieldDoesNotExist
  4. from django.db import models
  5. from netaddr import IPAddress, IPNetwork
  6. from ipam.fields import IPAddressField, IPNetworkField
  7. from netbox.registry import registry
  # Record describing one cacheable value of an object: the field/attribute name, its data
  # type identifier, its search weight, and the value itself.
  ObjectFieldValue = namedtuple(
      'ObjectFieldValue',
      ['name', 'type', 'weight', 'value'],
  )
  class FieldTypes:
      """
      String identifiers for the data types a cached search value may have.
      """
      # Numeric values
      FLOAT, INTEGER = 'float', 'int'

      # Textual values
      STRING = 'str'

      # IP address (INET) and IP network (CIDR) values
      INET, CIDR = 'inet', 'cidr'
  class LookupTypes:
      """
      Names of the lookup expressions available for matching search terms.

      Each value corresponds to one of Django's case-insensitive field lookups.
      """
      # Substring and whole-value matching
      PARTIAL, EXACT = 'icontains', 'iexact'

      # Anchored matching at the start or end of the value
      STARTSWITH, ENDSWITH = 'istartswith', 'iendswith'

      # Regular expression matching
      REGEX = 'iregex'
  class SearchIndex:
      """
      Base class for building search indexes.

      Attributes:
          model: The model class for which this index is used.
          category: The label of the group under which this indexer is categorized (for form field display). If none,
              the name of the model's app will be used.
          fields: An iterable of two-tuples defining the model fields to be indexed and the weight associated with each.
          display_attrs: An iterable of additional object attributes to include when displaying search results.
      """
      model = None
      category = None
      fields = ()
      display_attrs = ()

      @staticmethod
      def get_field_type(instance, field_name):
          """
          Return the data type of the specified model field.

          The field is resolved through `instance._meta.get_field()`, so any exception raised by that call
          (e.g. FieldDoesNotExist for a name which is not a model field) propagates to the caller.
          """
          field_cls = instance._meta.get_field(field_name).__class__

          # Checks are evaluated in order and the first match wins; anything unmatched is treated as a string.
          if issubclass(field_cls, (models.FloatField, models.DecimalField)):
              return FieldTypes.FLOAT
          if issubclass(field_cls, IPAddressField):
              return FieldTypes.INET
          if issubclass(field_cls, IPNetworkField):
              return FieldTypes.CIDR
          if issubclass(field_cls, models.IntegerField):
              return FieldTypes.INTEGER
          return FieldTypes.STRING

      @staticmethod
      def get_attr_type(instance, field_name):
          """
          Return the data type of the specified object attribute.

          The type is derived from the attribute's current value. Only exact type matches are recognized;
          instances of subclasses (for example `bool`, a subclass of `int`) fall through to the string type.
          """
          value = getattr(instance, field_name)
          if type(value) is str:
              return FieldTypes.STRING
          if type(value) is int:
              return FieldTypes.INTEGER
          if type(value) in (float, Decimal):
              return FieldTypes.FLOAT
          if type(value) is IPNetwork:
              return FieldTypes.CIDR
          if type(value) is IPAddress:
              return FieldTypes.INET
          return FieldTypes.STRING

      @staticmethod
      def get_field_value(instance, field_name):
          """
          Return the value of the specified model field as a string, or None if the field has no value.
          """
          value = getattr(instance, field_name)
          # str(None) yields the truthy string 'None', which to_cache() would index as though it were real
          # content. Pass None through so that unset fields are skipped instead.
          if value is None:
              return None
          return str(value)

      @classmethod
      def get_category(cls):
          """
          Return the explicitly assigned category, falling back to the verbose name of the model's app.
          """
          return cls.category or cls.model._meta.app_config.verbose_name

      @classmethod
      def to_cache(cls, instance, custom_fields=None):
          """
          Return a list of ObjectFieldValue representing the instance fields to be cached.

          Args:
              instance: The instance being cached.
              custom_fields: An iterable of CustomFields to include when caching the instance. If None, all custom fields
                  defined for the model will be included. (This can also be provided during bulk caching to avoid looking
                  up the available custom fields for each instance.)
          """
          values = []

          # Capture built-in fields
          for name, weight in cls.fields:
              try:
                  type_ = cls.get_field_type(instance, name)
              except FieldDoesNotExist:
                  # Not a concrete field; handle as an object attribute
                  type_ = cls.get_attr_type(instance, name)
              value = cls.get_field_value(instance, name)
              # Skip fields which have no type or no (non-empty) value
              if type_ and value:
                  values.append(
                      ObjectFieldValue(name, type_, weight, value)
                  )

          # Capture custom fields (only for instances which actually hold custom field data)
          if getattr(instance, 'custom_field_data', None):
              if custom_fields is None:
                  custom_fields = instance.custom_fields
              for cf in custom_fields:
                  type_ = cf.search_type
                  value = instance.custom_field_data.get(cf.name)
                  weight = cf.search_weight
                  # A falsy search type, value, or weight excludes the custom field from the cache
                  if type_ and value and weight:
                      values.append(
                          ObjectFieldValue(f'cf_{cf.name}', type_, weight, value)
                      )

          return values
  def get_indexer(model):
      """
      Look up and return the SearchIndex class registered for the given model.

      The registry key has the form `<app_label>.<model_name>`; a KeyError is raised if nothing has been
      registered under that key.
      """
      opts = model._meta
      return registry['search'][f'{opts.app_label}.{opts.model_name}']
  def register_search(cls):
      """
      Class decorator which registers a SearchIndex class.

      The class is stored in the 'search' registry under its model's `<app_label>.<model_name>` key and is
      returned unchanged.
      """
      opts = cls.model._meta
      registry['search'][f'{opts.app_label}.{opts.model_name}'] = cls
      return cls