# lookups.py (5.3 KB)
  1. from django.db.models import IntegerField, Lookup, Transform, lookups
  2. class NetFieldDecoratorMixin:
  3. def process_lhs(self, qn, connection, lhs=None):
  4. lhs = lhs or self.lhs
  5. lhs_string, lhs_params = qn.compile(lhs)
  6. lhs_string = f'TEXT({lhs_string})'
  7. return lhs_string, lhs_params
  8. class IExact(NetFieldDecoratorMixin, lookups.IExact):
  9. def get_rhs_op(self, connection, rhs):
  10. return f'= LOWER({rhs})'
  11. class EndsWith(NetFieldDecoratorMixin, lookups.EndsWith):
  12. pass
  13. class IEndsWith(NetFieldDecoratorMixin, lookups.IEndsWith):
  14. pass
  15. def get_rhs_op(self, connection, rhs):
  16. return f'LIKE LOWER({rhs})'
  17. class StartsWith(NetFieldDecoratorMixin, lookups.StartsWith):
  18. lookup_name = 'startswith'
  19. class IStartsWith(NetFieldDecoratorMixin, lookups.IStartsWith):
  20. pass
  21. def get_rhs_op(self, connection, rhs):
  22. return f'LIKE LOWER({rhs})'
  23. class Regex(NetFieldDecoratorMixin, lookups.Regex):
  24. pass
  25. class IRegex(NetFieldDecoratorMixin, lookups.IRegex):
  26. pass
  27. class NetContainsOrEquals(Lookup):
  28. lookup_name = 'net_contains_or_equals'
  29. def as_sql(self, qn, connection):
  30. lhs, lhs_params = self.process_lhs(qn, connection)
  31. rhs, rhs_params = self.process_rhs(qn, connection)
  32. params = lhs_params + rhs_params
  33. return f'{lhs} >>= {rhs}', params
  34. class NetContains(Lookup):
  35. lookup_name = 'net_contains'
  36. def as_sql(self, qn, connection):
  37. lhs, lhs_params = self.process_lhs(qn, connection)
  38. rhs, rhs_params = self.process_rhs(qn, connection)
  39. params = lhs_params + rhs_params
  40. return f'{lhs} >> {rhs}', params
  41. class NetContained(Lookup):
  42. lookup_name = 'net_contained'
  43. def as_sql(self, qn, connection):
  44. lhs, lhs_params = self.process_lhs(qn, connection)
  45. rhs, rhs_params = self.process_rhs(qn, connection)
  46. params = lhs_params + rhs_params
  47. return f'{lhs} << {rhs}', params
  48. class NetContainedOrEqual(Lookup):
  49. lookup_name = 'net_contained_or_equal'
  50. def as_sql(self, qn, connection):
  51. lhs, lhs_params = self.process_lhs(qn, connection)
  52. rhs, rhs_params = self.process_rhs(qn, connection)
  53. params = lhs_params + rhs_params
  54. return f'{lhs} <<= {rhs}', params
  55. class NetHost(Lookup):
  56. lookup_name = 'net_host'
  57. def as_sql(self, qn, connection):
  58. lhs, lhs_params = self.process_lhs(qn, connection)
  59. rhs, rhs_params = self.process_rhs(qn, connection)
  60. # Query parameters are automatically converted to IPNetwork objects, which are then turned to strings. We need
  61. # to omit the mask portion of the object's string representation to match PostgreSQL's HOST() function.
  62. # Note: params may be tuples (Django 6.0+) or lists (older Django), so convert before mutating.
  63. rhs_params = list(rhs_params)
  64. if rhs_params:
  65. rhs_params[0] = rhs_params[0].split('/')[0]
  66. params = list(lhs_params) + rhs_params
  67. return f'HOST({lhs}) = {rhs}', params
  68. class NetIn(Lookup):
  69. lookup_name = 'net_in'
  70. def get_prep_lookup(self):
  71. # Don't cast the query value to a netaddr object, since it may or may not include a mask.
  72. return self.rhs
  73. def as_sql(self, qn, connection):
  74. lhs = self.process_lhs(qn, connection)[0]
  75. rhs_params = self.process_rhs(qn, connection)[1]
  76. with_mask, without_mask = [], []
  77. for address in rhs_params[0]:
  78. if '/' in address:
  79. with_mask.append(address)
  80. else:
  81. without_mask.append(address)
  82. address_in_clause = self.create_in_clause('{} IN ('.format(lhs), len(with_mask))
  83. host_in_clause = self.create_in_clause('HOST({}) IN ('.format(lhs), len(without_mask))
  84. if with_mask and not without_mask:
  85. return address_in_clause, with_mask
  86. if not with_mask and without_mask:
  87. return host_in_clause, without_mask
  88. in_clause = '({}) OR ({})'.format(address_in_clause, host_in_clause)
  89. with_mask.extend(without_mask)
  90. return in_clause, with_mask
  91. @staticmethod
  92. def create_in_clause(clause_part, max_size):
  93. clause_elements = [clause_part]
  94. for offset in range(0, max_size):
  95. if offset > 0:
  96. clause_elements.append(', ')
  97. clause_elements.append('%s')
  98. clause_elements.append(')')
  99. return ''.join(clause_elements)
  100. class NetHostContained(Lookup):
  101. """
  102. Check for the host portion of an IP address without regard to its mask. This allows us to find e.g. 192.0.2.1/24
  103. when specifying a parent prefix of 192.0.2.0/26.
  104. """
  105. lookup_name = 'net_host_contained'
  106. def as_sql(self, qn, connection):
  107. lhs, lhs_params = self.process_lhs(qn, connection)
  108. rhs, rhs_params = self.process_rhs(qn, connection)
  109. params = lhs_params + rhs_params
  110. return f'CAST(HOST({lhs}) AS INET) <<= {rhs}', params
  111. class NetFamily(Transform):
  112. lookup_name = 'family'
  113. function = 'FAMILY'
  114. @property
  115. def output_field(self):
  116. return IntegerField()
  117. class NetMaskLength(Transform):
  118. function = 'MASKLEN'
  119. lookup_name = 'net_mask_length'
  120. @property
  121. def output_field(self):
  122. return IntegerField()
  123. class Host(Transform):
  124. function = 'HOST'
  125. lookup_name = 'host'
  126. class Inet(Transform):
  127. function = 'INET'
  128. lookup_name = 'inet'