lookups.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. from django.db.models import IntegerField, Lookup, Transform, lookups
  2. class NetFieldDecoratorMixin(object):
  3. def process_lhs(self, qn, connection, lhs=None):
  4. lhs = lhs or self.lhs
  5. lhs_string, lhs_params = qn.compile(lhs)
  6. lhs_string = 'TEXT(%s)' % lhs_string
  7. return lhs_string, lhs_params
  8. class IExact(NetFieldDecoratorMixin, lookups.IExact):
  9. def get_rhs_op(self, connection, rhs):
  10. return '= LOWER(%s)' % rhs
  11. class EndsWith(NetFieldDecoratorMixin, lookups.EndsWith):
  12. pass
  13. class IEndsWith(NetFieldDecoratorMixin, lookups.IEndsWith):
  14. pass
  15. def get_rhs_op(self, connection, rhs):
  16. return 'LIKE LOWER(%s)' % rhs
  17. class StartsWith(NetFieldDecoratorMixin, lookups.StartsWith):
  18. lookup_name = 'startswith'
  19. class IStartsWith(NetFieldDecoratorMixin, lookups.IStartsWith):
  20. pass
  21. def get_rhs_op(self, connection, rhs):
  22. return 'LIKE LOWER(%s)' % rhs
  23. class Regex(NetFieldDecoratorMixin, lookups.Regex):
  24. pass
  25. class IRegex(NetFieldDecoratorMixin, lookups.IRegex):
  26. pass
  27. class NetContainsOrEquals(Lookup):
  28. lookup_name = 'net_contains_or_equals'
  29. def as_sql(self, qn, connection):
  30. lhs, lhs_params = self.process_lhs(qn, connection)
  31. rhs, rhs_params = self.process_rhs(qn, connection)
  32. params = lhs_params + rhs_params
  33. return '%s >>= %s' % (lhs, rhs), params
  34. class NetContains(Lookup):
  35. lookup_name = 'net_contains'
  36. def as_sql(self, qn, connection):
  37. lhs, lhs_params = self.process_lhs(qn, connection)
  38. rhs, rhs_params = self.process_rhs(qn, connection)
  39. params = lhs_params + rhs_params
  40. return '%s >> %s' % (lhs, rhs), params
  41. class NetContained(Lookup):
  42. lookup_name = 'net_contained'
  43. def as_sql(self, qn, connection):
  44. lhs, lhs_params = self.process_lhs(qn, connection)
  45. rhs, rhs_params = self.process_rhs(qn, connection)
  46. params = lhs_params + rhs_params
  47. return '%s << %s' % (lhs, rhs), params
  48. class NetContainedOrEqual(Lookup):
  49. lookup_name = 'net_contained_or_equal'
  50. def as_sql(self, qn, connection):
  51. lhs, lhs_params = self.process_lhs(qn, connection)
  52. rhs, rhs_params = self.process_rhs(qn, connection)
  53. params = lhs_params + rhs_params
  54. return '%s <<= %s' % (lhs, rhs), params
  55. class NetHost(Lookup):
  56. lookup_name = 'net_host'
  57. def as_sql(self, qn, connection):
  58. lhs, lhs_params = self.process_lhs(qn, connection)
  59. rhs, rhs_params = self.process_rhs(qn, connection)
  60. # Query parameters are automatically converted to IPNetwork objects, which are then turned to strings. We need
  61. # to omit the mask portion of the object's string representation to match PostgreSQL's HOST() function.
  62. if rhs_params:
  63. rhs_params[0] = rhs_params[0].split('/')[0]
  64. params = lhs_params + rhs_params
  65. return 'HOST(%s) = %s' % (lhs, rhs), params
  66. class NetIn(Lookup):
  67. lookup_name = 'net_in'
  68. def get_prep_lookup(self):
  69. # Don't cast the query value to a netaddr object, since it may or may not include a mask.
  70. return self.rhs
  71. def as_sql(self, qn, connection):
  72. lhs, lhs_params = self.process_lhs(qn, connection)
  73. rhs, rhs_params = self.process_rhs(qn, connection)
  74. with_mask, without_mask = [], []
  75. for address in rhs_params[0]:
  76. if '/' in address:
  77. with_mask.append(address)
  78. else:
  79. without_mask.append(address)
  80. address_in_clause = self.create_in_clause('{} IN ('.format(lhs), len(with_mask))
  81. host_in_clause = self.create_in_clause('HOST({}) IN ('.format(lhs), len(without_mask))
  82. if with_mask and not without_mask:
  83. return address_in_clause, with_mask
  84. elif not with_mask and without_mask:
  85. return host_in_clause, without_mask
  86. in_clause = '({}) OR ({})'.format(address_in_clause, host_in_clause)
  87. with_mask.extend(without_mask)
  88. return in_clause, with_mask
  89. @staticmethod
  90. def create_in_clause(clause_part, max_size):
  91. clause_elements = [clause_part]
  92. for offset in range(0, max_size):
  93. if offset > 0:
  94. clause_elements.append(', ')
  95. clause_elements.append('%s')
  96. clause_elements.append(')')
  97. return ''.join(clause_elements)
  98. class NetHostContained(Lookup):
  99. """
  100. Check for the host portion of an IP address without regard to its mask. This allows us to find e.g. 192.0.2.1/24
  101. when specifying a parent prefix of 192.0.2.0/26.
  102. """
  103. lookup_name = 'net_host_contained'
  104. def as_sql(self, qn, connection):
  105. lhs, lhs_params = self.process_lhs(qn, connection)
  106. rhs, rhs_params = self.process_rhs(qn, connection)
  107. params = lhs_params + rhs_params
  108. return 'CAST(HOST(%s) AS INET) << %s' % (lhs, rhs), params
  109. class NetMaskLength(Transform):
  110. lookup_name = 'net_mask_length'
  111. function = 'MASKLEN'
  112. @property
  113. def output_field(self):
  114. return IntegerField()