| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- from django.db.models import IntegerField, Lookup, Transform, lookups
class NetFieldDecoratorMixin(object):
    """Mixin for lookups that wraps the compiled left-hand side in SQL ``TEXT(...)``."""

    def process_lhs(self, qn, connection, lhs=None):
        """
        Compile the left-hand side and wrap the resulting SQL in ``TEXT()``.

        :param qn: Compiler object; only its ``compile()`` method is used here.
        :param connection: Database connection (not used by this method).
        :param lhs: Optional expression to compile; falls back to ``self.lhs`` when falsy.
        :return: Tuple of (SQL string, parameters) as produced by ``qn.compile()``, with the SQL wrapped.
        """
        target = lhs or self.lhs
        compiled_sql, compiled_params = qn.compile(target)
        return 'TEXT(%s)' % compiled_sql, compiled_params
class IExact(NetFieldDecoratorMixin, lookups.IExact):
    """``IExact`` lookup whose left-hand side is wrapped in ``TEXT()`` by NetFieldDecoratorMixin."""

    def get_rhs_op(self, connection, rhs):
        """Return the operator fragment comparing against ``LOWER(rhs)``."""
        operator_sql = '= LOWER(%s)' % rhs
        return operator_sql
class EndsWith(NetFieldDecoratorMixin, lookups.EndsWith):
    """``EndsWith`` lookup whose left-hand side is wrapped in ``TEXT()`` by NetFieldDecoratorMixin."""
class IEndsWith(NetFieldDecoratorMixin, lookups.IEndsWith):
    """``IEndsWith`` lookup whose left-hand side is wrapped in ``TEXT()`` by NetFieldDecoratorMixin."""

    def get_rhs_op(self, connection, rhs):
        """Return the operator fragment matching against ``LOWER(rhs)`` with ``LIKE``."""
        operator_sql = 'LIKE LOWER(%s)' % rhs
        return operator_sql
class StartsWith(NetFieldDecoratorMixin, lookups.StartsWith):
    """``StartsWith`` lookup whose left-hand side is wrapped in ``TEXT()`` by NetFieldDecoratorMixin."""

    # Name under which this lookup is registered and referenced in query filters.
    lookup_name = 'startswith'
class IStartsWith(NetFieldDecoratorMixin, lookups.IStartsWith):
    """``IStartsWith`` lookup whose left-hand side is wrapped in ``TEXT()`` by NetFieldDecoratorMixin."""

    def get_rhs_op(self, connection, rhs):
        """Return the operator fragment matching against ``LOWER(rhs)`` with ``LIKE``."""
        operator_sql = 'LIKE LOWER(%s)' % rhs
        return operator_sql
class Regex(NetFieldDecoratorMixin, lookups.Regex):
    """``Regex`` lookup whose left-hand side is wrapped in ``TEXT()`` by NetFieldDecoratorMixin."""
class IRegex(NetFieldDecoratorMixin, lookups.IRegex):
    """``IRegex`` lookup whose left-hand side is wrapped in ``TEXT()`` by NetFieldDecoratorMixin."""
class NetContainsOrEquals(Lookup):
    """Lookup ``net_contains_or_equals``: renders ``<lhs> >>= <rhs>``."""

    lookup_name = 'net_contains_or_equals'

    def as_sql(self, qn, connection):
        """Return the ``>>=`` comparison SQL and the combined left/right parameters."""
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        sql = '%s >>= %s' % (left_sql, right_sql)
        return sql, left_params + right_params
class NetContains(Lookup):
    """Lookup ``net_contains``: renders ``<lhs> >> <rhs>``."""

    lookup_name = 'net_contains'

    def as_sql(self, qn, connection):
        """Return the ``>>`` comparison SQL and the combined left/right parameters."""
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        sql = '%s >> %s' % (left_sql, right_sql)
        return sql, left_params + right_params
class NetContained(Lookup):
    """Lookup ``net_contained``: renders ``<lhs> << <rhs>``."""

    lookup_name = 'net_contained'

    def as_sql(self, qn, connection):
        """Return the ``<<`` comparison SQL and the combined left/right parameters."""
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        sql = '%s << %s' % (left_sql, right_sql)
        return sql, left_params + right_params
class NetContainedOrEqual(Lookup):
    """Lookup ``net_contained_or_equal``: renders ``<lhs> <<= <rhs>``."""

    lookup_name = 'net_contained_or_equal'

    def as_sql(self, qn, connection):
        """Return the ``<<=`` comparison SQL and the combined left/right parameters."""
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        sql = '%s <<= %s' % (left_sql, right_sql)
        return sql, left_params + right_params
class NetHost(Lookup):
    """Lookup ``net_host``: renders ``HOST(<lhs>) = <rhs>``, comparing against the mask-less query value."""

    lookup_name = 'net_host'

    def as_sql(self, qn, connection):
        """
        Return the ``HOST(lhs) = rhs`` SQL and its parameters.

        The first right-hand parameter has everything from the first ``/`` onward removed before it is returned.

        :return: Tuple of (SQL string, list of parameters).
        """
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        # Work on list copies: the parameter sequences may be tuples (which reject item assignment and cannot be
        # concatenated with lists), and the sequence handed back by process_rhs() should not be modified in place.
        lhs_params = list(lhs_params)
        rhs_params = list(rhs_params)
        # Query parameters are automatically converted to IPNetwork objects, which are then turned to strings. We need
        # to omit the mask portion of the object's string representation to match PostgreSQL's HOST() function.
        if rhs_params:
            rhs_params[0] = rhs_params[0].split('/')[0]
        params = lhs_params + rhs_params
        return 'HOST(%s) = %s' % (lhs, rhs), params
class NetIn(Lookup):
    """
    Lookup ``net_in``: match the left-hand side against a collection of address strings.

    Values containing a ``/`` are compared against the left-hand side directly; values without one are compared
    against ``HOST(<lhs>)``.
    """

    lookup_name = 'net_in'

    def get_prep_lookup(self):
        # Don't cast the query value to a netaddr object, since it may or may not include a mask.
        return self.rhs

    def as_sql(self, qn, connection):
        """
        Build the ``IN`` clause(s) for the supplied addresses.

        :return: Tuple of (SQL string, list of parameters). Left-hand parameters are emitted once for every
            occurrence of the left-hand SQL so placeholders and parameters stay aligned.
        :raises EmptyResultSet: If no addresses were supplied; ``IN ()`` is not valid SQL.
        """
        # Imported locally so this class does not rely on a change to the module-level imports.
        from django.core.exceptions import EmptyResultSet

        lhs, lhs_params = self.process_lhs(qn, connection)
        # Only the parameters are needed from the right-hand side; the placeholders are generated below.
        _, rhs_params = self.process_rhs(qn, connection)
        lhs_params = list(lhs_params)

        with_mask, without_mask = [], []
        for address in rhs_params[0]:
            if '/' in address:
                with_mask.append(address)
            else:
                without_mask.append(address)

        if not with_mask and not without_mask:
            raise EmptyResultSet

        clauses, params = [], []
        if with_mask:
            clauses.append(self.create_in_clause('{} IN ('.format(lhs), len(with_mask)))
            params.extend(lhs_params)
            params.extend(with_mask)
        if without_mask:
            clauses.append(self.create_in_clause('HOST({}) IN ('.format(lhs), len(without_mask)))
            params.extend(lhs_params)
            params.extend(without_mask)

        if len(clauses) == 1:
            return clauses[0], params
        return '({}) OR ({})'.format(*clauses), params

    @staticmethod
    def create_in_clause(clause_part, max_size):
        """
        Append ``max_size`` comma-separated ``%s`` placeholders and a closing parenthesis to ``clause_part``.

        :param clause_part: Opening fragment of the clause, e.g. ``'col IN ('``.
        :param max_size: Number of placeholders to generate.
        :return: The completed clause string.
        """
        return '{}{})'.format(clause_part, ', '.join(['%s'] * max_size))
class NetHostContained(Lookup):
    """
    Match on the host portion of an IP address, ignoring its mask. For example, 192.0.2.1/24 can be found when
    a parent prefix of 192.0.2.0/26 is specified.
    """

    lookup_name = 'net_host_contained'

    def as_sql(self, qn, connection):
        """Return ``CAST(HOST(lhs) AS INET) << rhs`` and the combined left/right parameters."""
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        sql = 'CAST(HOST(%s) AS INET) << %s' % (left_sql, right_sql)
        return sql, left_params + right_params
class NetMaskLength(Transform):
    """Transform ``net_mask_length``: applies the SQL ``MASKLEN`` function and yields an integer field."""

    function = 'MASKLEN'
    lookup_name = 'net_mask_length'

    @property
    def output_field(self):
        """Return a new ``IntegerField`` describing the transform's output."""
        return IntegerField()
|