diff options
Diffstat (limited to 'webapp/django/db/backends/postgresql')
-rw-r--r-- | webapp/django/db/backends/postgresql/__init__.py | 0 | ||||
-rw-r--r-- | webapp/django/db/backends/postgresql/base.py | 148 | ||||
-rw-r--r-- | webapp/django/db/backends/postgresql/client.py | 17 | ||||
-rw-r--r-- | webapp/django/db/backends/postgresql/creation.py | 38 | ||||
-rw-r--r-- | webapp/django/db/backends/postgresql/introspection.py | 86 | ||||
-rw-r--r-- | webapp/django/db/backends/postgresql/operations.py | 144 | ||||
-rw-r--r-- | webapp/django/db/backends/postgresql/version.py | 18 |
7 files changed, 451 insertions, 0 deletions
diff --git a/webapp/django/db/backends/postgresql/__init__.py b/webapp/django/db/backends/postgresql/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/webapp/django/db/backends/postgresql/__init__.py diff --git a/webapp/django/db/backends/postgresql/base.py b/webapp/django/db/backends/postgresql/base.py new file mode 100644 index 0000000000..376d7ba2c8 --- /dev/null +++ b/webapp/django/db/backends/postgresql/base.py @@ -0,0 +1,148 @@ +""" +PostgreSQL database backend for Django. + +Requires psycopg 1: http://initd.org/projects/psycopg1 +""" + +from django.db.backends import * +from django.db.backends.postgresql.client import DatabaseClient +from django.db.backends.postgresql.creation import DatabaseCreation +from django.db.backends.postgresql.introspection import DatabaseIntrospection +from django.db.backends.postgresql.operations import DatabaseOperations +from django.db.backends.postgresql.version import get_version +from django.utils.encoding import smart_str, smart_unicode + +try: + import psycopg as Database +except ImportError, e: + from django.core.exceptions import ImproperlyConfigured + raise ImproperlyConfigured("Error loading psycopg module: %s" % e) + +DatabaseError = Database.DatabaseError +IntegrityError = Database.IntegrityError + +class UnicodeCursorWrapper(object): + """ + A thin wrapper around psycopg cursors that allows them to accept Unicode + strings as params. + + This is necessary because psycopg doesn't apply any DB quoting to + parameters that are Unicode strings. If a param is Unicode, this will + convert it to a bytestring using database client's encoding before passing + it to psycopg. + + All results retrieved from the database are converted into Unicode strings + before being returned to the caller. + """ + def __init__(self, cursor, charset): + self.cursor = cursor + self.charset = charset + + def format_params(self, params): + if isinstance(params, dict): + result = {} + charset = self.charset + for key, value in params.items(): + result[smart_str(key, charset)] = smart_str(value, charset) + return result + else: + return tuple([smart_str(p, self.charset, True) for p in params]) + + def execute(self, sql, params=()): + return self.cursor.execute(smart_str(sql, self.charset), self.format_params(params)) + + def executemany(self, sql, param_list): + new_param_list = [self.format_params(params) for params in param_list] + return self.cursor.executemany(sql, new_param_list) + + def __getattr__(self, attr): + if attr in self.__dict__: + return self.__dict__[attr] + else: + return getattr(self.cursor, attr) + + def __iter__(self): + return iter(self.cursor) + +class DatabaseFeatures(BaseDatabaseFeatures): + uses_savepoints = True + +class DatabaseWrapper(BaseDatabaseWrapper): + operators = { + 'exact': '= %s', + 'iexact': '= UPPER(%s)', + 'contains': 'LIKE %s', + 'icontains': 'LIKE UPPER(%s)', + 'regex': '~ %s', + 'iregex': '~* %s', + 'gt': '> %s', + 'gte': '>= %s', + 'lt': '< %s', + 'lte': '<= %s', + 'startswith': 'LIKE %s', + 'endswith': 'LIKE %s', + 'istartswith': 'LIKE UPPER(%s)', + 'iendswith': 'LIKE UPPER(%s)', + } + + def __init__(self, *args, **kwargs): + super(DatabaseWrapper, self).__init__(*args, **kwargs) + + self.features = DatabaseFeatures() + self.ops = DatabaseOperations() + self.client = DatabaseClient() + self.creation = DatabaseCreation(self) + self.introspection = DatabaseIntrospection(self) + self.validation = BaseDatabaseValidation() + + def _cursor(self, settings): + set_tz = False + if self.connection is None: + set_tz = True + if settings.DATABASE_NAME == '': + from django.core.exceptions import ImproperlyConfigured + raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.") + conn_string = "dbname=%s" % settings.DATABASE_NAME + if settings.DATABASE_USER: + conn_string = "user=%s %s" % (settings.DATABASE_USER, conn_string) + if settings.DATABASE_PASSWORD: + conn_string += " password='%s'" % settings.DATABASE_PASSWORD + if settings.DATABASE_HOST: + conn_string += " host=%s" % settings.DATABASE_HOST + if settings.DATABASE_PORT: + conn_string += " port=%s" % settings.DATABASE_PORT + self.connection = Database.connect(conn_string, **self.options) + self.connection.set_isolation_level(1) # make transactions transparent to all cursors + cursor = self.connection.cursor() + if set_tz: + cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) + if not hasattr(self, '_version'): + version = get_version(cursor) + self.__class__._version = version + if version < (8, 0): + # No savepoint support for earlier version of PostgreSQL. + self.features.uses_savepoints = False + cursor.execute("SET client_encoding to 'UNICODE'") + cursor = UnicodeCursorWrapper(cursor, 'utf-8') + return cursor + +def typecast_string(s): + """ + Cast all returned strings to unicode strings. + """ + if not s and not isinstance(s, str): + return s + return smart_unicode(s) + +# Register these custom typecasts, because Django expects dates/times to be +# in Python's native (standard-library) datetime/time format, whereas psycopg +# use mx.DateTime by default. +try: + Database.register_type(Database.new_type((1082,), "DATE", util.typecast_date)) +except AttributeError: + raise Exception("You appear to be using psycopg version 2. Set your DATABASE_ENGINE to 'postgresql_psycopg2' instead of 'postgresql'.") +Database.register_type(Database.new_type((1083,1266), "TIME", util.typecast_time)) +Database.register_type(Database.new_type((1114,1184), "TIMESTAMP", util.typecast_timestamp)) +Database.register_type(Database.new_type((16,), "BOOLEAN", util.typecast_boolean)) +Database.register_type(Database.new_type((1700,), "NUMERIC", util.typecast_decimal)) +Database.register_type(Database.new_type(Database.types[1043].values, 'STRING', typecast_string)) diff --git a/webapp/django/db/backends/postgresql/client.py b/webapp/django/db/backends/postgresql/client.py new file mode 100644 index 0000000000..28daed833a --- /dev/null +++ b/webapp/django/db/backends/postgresql/client.py @@ -0,0 +1,17 @@ +from django.db.backends import BaseDatabaseClient +from django.conf import settings +import os + +class DatabaseClient(BaseDatabaseClient): + def runshell(self): + args = ['psql'] + if settings.DATABASE_USER: + args += ["-U", settings.DATABASE_USER] + if settings.DATABASE_PASSWORD: + args += ["-W"] + if settings.DATABASE_HOST: + args.extend(["-h", settings.DATABASE_HOST]) + if settings.DATABASE_PORT: + args.extend(["-p", str(settings.DATABASE_PORT)]) + args += [settings.DATABASE_NAME] + os.execvp('psql', args) diff --git a/webapp/django/db/backends/postgresql/creation.py b/webapp/django/db/backends/postgresql/creation.py new file mode 100644 index 0000000000..3e537e345e --- /dev/null +++ b/webapp/django/db/backends/postgresql/creation.py @@ -0,0 +1,38 @@ +from django.conf import settings +from django.db.backends.creation import BaseDatabaseCreation + +class DatabaseCreation(BaseDatabaseCreation): + # This dictionary maps Field objects to their associated PostgreSQL column + # types, as strings. Column-type strings can contain format strings; they'll + # be interpolated against the values of Field.__dict__ before being output. + # If a column type is set to None, it won't be included in the output. + data_types = { + 'AutoField': 'serial', + 'BooleanField': 'boolean', + 'CharField': 'varchar(%(max_length)s)', + 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', + 'DateField': 'date', + 'DateTimeField': 'timestamp with time zone', + 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FileField': 'varchar(%(max_length)s)', + 'FilePathField': 'varchar(%(max_length)s)', + 'FloatField': 'double precision', + 'IntegerField': 'integer', + 'IPAddressField': 'inet', + 'NullBooleanField': 'boolean', + 'OneToOneField': 'integer', + 'PhoneNumberField': 'varchar(20)', + 'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)', + 'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)', + 'SlugField': 'varchar(%(max_length)s)', + 'SmallIntegerField': 'smallint', + 'TextField': 'text', + 'TimeField': 'time', + 'USStateField': 'varchar(2)', + } + + def sql_table_creation_suffix(self): + assert settings.TEST_DATABASE_COLLATION is None, "PostgreSQL does not support collation setting at database creation time." + if settings.TEST_DATABASE_CHARSET: + return "WITH ENCODING '%s'" % settings.TEST_DATABASE_CHARSET + return '' diff --git a/webapp/django/db/backends/postgresql/introspection.py b/webapp/django/db/backends/postgresql/introspection.py new file mode 100644 index 0000000000..7b3ab3bb8a --- /dev/null +++ b/webapp/django/db/backends/postgresql/introspection.py @@ -0,0 +1,86 @@ +from django.db.backends import BaseDatabaseIntrospection + +class DatabaseIntrospection(BaseDatabaseIntrospection): + # Maps type codes to Django Field types. + data_types_reverse = { + 16: 'BooleanField', + 21: 'SmallIntegerField', + 23: 'IntegerField', + 25: 'TextField', + 701: 'FloatField', + 869: 'IPAddressField', + 1043: 'CharField', + 1082: 'DateField', + 1083: 'TimeField', + 1114: 'DateTimeField', + 1184: 'DateTimeField', + 1266: 'TimeField', + 1700: 'DecimalField', + } + + def get_table_list(self, cursor): + "Returns a list of table names in the current database." + cursor.execute(""" + SELECT c.relname + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE c.relkind IN ('r', 'v', '') + AND n.nspname NOT IN ('pg_catalog', 'pg_toast') + AND pg_catalog.pg_table_is_visible(c.oid)""") + return [row[0] for row in cursor.fetchall()] + + def get_table_description(self, cursor, table_name): + "Returns a description of the table, with the DB-API cursor.description interface." + cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)) + return cursor.description + + def get_relations(self, cursor, table_name): + """ + Returns a dictionary of {field_index: (field_index_other_table, other_table)} + representing all relationships to the given table. Indexes are 0-based. + """ + cursor.execute(""" + SELECT con.conkey, con.confkey, c2.relname + FROM pg_constraint con, pg_class c1, pg_class c2 + WHERE c1.oid = con.conrelid + AND c2.oid = con.confrelid + AND c1.relname = %s + AND con.contype = 'f'""", [table_name]) + relations = {} + for row in cursor.fetchall(): + try: + # row[0] and row[1] are like "{2}", so strip the curly braces. + relations[int(row[0][1:-1]) - 1] = (int(row[1][1:-1]) - 1, row[2]) + except ValueError: + continue + return relations + + def get_indexes(self, cursor, table_name): + """ + Returns a dictionary of fieldname -> infodict for the given table, + where each infodict is in the format: + {'primary_key': boolean representing whether it's the primary key, + 'unique': boolean representing whether it's a unique index} + """ + # This query retrieves each index on the given table, including the + # first associated field name + cursor.execute(""" + SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary + FROM pg_catalog.pg_class c, pg_catalog.pg_class c2, + pg_catalog.pg_index idx, pg_catalog.pg_attribute attr + WHERE c.oid = idx.indrelid + AND idx.indexrelid = c2.oid + AND attr.attrelid = c.oid + AND attr.attnum = idx.indkey[0] + AND c.relname = %s""", [table_name]) + indexes = {} + for row in cursor.fetchall(): + # row[1] (idx.indkey) is stored in the DB as an array. It comes out as + # a string of space-separated integers. This designates the field + # indexes (1-based) of the fields that have indexes on the table. + # Here, we skip any indexes across multiple fields. + if ' ' in row[1]: + continue + indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]} + return indexes + diff --git a/webapp/django/db/backends/postgresql/operations.py b/webapp/django/db/backends/postgresql/operations.py new file mode 100644 index 0000000000..01cc1fc8b7 --- /dev/null +++ b/webapp/django/db/backends/postgresql/operations.py @@ -0,0 +1,144 @@ +import re + +from django.db.backends import BaseDatabaseOperations + +server_version_re = re.compile(r'PostgreSQL (\d{1,2})\.(\d{1,2})\.?(\d{1,2})?') + +# This DatabaseOperations class lives in here instead of base.py because it's +# used by both the 'postgresql' and 'postgresql_psycopg2' backends. + +class DatabaseOperations(BaseDatabaseOperations): + def __init__(self): + self._postgres_version = None + + def _get_postgres_version(self): + if self._postgres_version is None: + from django.db import connection + cursor = connection.cursor() + cursor.execute("SELECT version()") + version_string = cursor.fetchone()[0] + m = server_version_re.match(version_string) + if not m: + raise Exception('Unable to determine PostgreSQL version from version() function string: %r' % version_string) + self._postgres_version = [int(val) for val in m.groups() if val] + return self._postgres_version + postgres_version = property(_get_postgres_version) + + def date_extract_sql(self, lookup_type, field_name): + # http://www.postgresql.org/docs/8.0/static/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT + return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name) + + def date_trunc_sql(self, lookup_type, field_name): + # http://www.postgresql.org/docs/8.0/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC + return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) + + def deferrable_sql(self): + return " DEFERRABLE INITIALLY DEFERRED" + + def lookup_cast(self, lookup_type): + lookup = '%s' + + # Cast text lookups to text to allow things like filter(x__contains=4) + if lookup_type in ('iexact', 'contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith'): + lookup = "%s::text" + + # Use UPPER(x) for case-insensitive lookups; it's faster. + if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'): + lookup = 'UPPER(%s)' % lookup + + return lookup + + def field_cast_sql(self, db_type): + if db_type == 'inet': + return 'HOST(%s)' + return '%s' + + def last_insert_id(self, cursor, table_name, pk_name): + cursor.execute("SELECT CURRVAL('\"%s_%s_seq\"')" % (table_name, pk_name)) + return cursor.fetchone()[0] + + def no_limit_value(self): + return None + + def quote_name(self, name): + if name.startswith('"') and name.endswith('"'): + return name # Quoting once is enough. + return '"%s"' % name + + def sql_flush(self, style, tables, sequences): + if tables: + if self.postgres_version[0] >= 8 and self.postgres_version[1] >= 1: + # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to* + # in order to be able to truncate tables referenced by a foreign + # key in any other table. The result is a single SQL TRUNCATE + # statement. + sql = ['%s %s;' % \ + (style.SQL_KEYWORD('TRUNCATE'), + style.SQL_FIELD(', '.join([self.quote_name(table) for table in tables])) + )] + else: + # Older versions of Postgres can't do TRUNCATE in a single call, so + # they must use a simple delete. + sql = ['%s %s %s;' % \ + (style.SQL_KEYWORD('DELETE'), + style.SQL_KEYWORD('FROM'), + style.SQL_FIELD(self.quote_name(table)) + ) for table in tables] + + # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements + # to reset sequence indices + for sequence_info in sequences: + table_name = sequence_info['table'] + column_name = sequence_info['column'] + if column_name and len(column_name) > 0: + sequence_name = '%s_%s_seq' % (table_name, column_name) + else: + sequence_name = '%s_id_seq' % table_name + sql.append("%s setval('%s', 1, false);" % \ + (style.SQL_KEYWORD('SELECT'), + style.SQL_FIELD(self.quote_name(sequence_name))) + ) + return sql + else: + return [] + + def sequence_reset_sql(self, style, model_list): + from django.db import models + output = [] + qn = self.quote_name + for model in model_list: + # Use `coalesce` to set the sequence for each model to the max pk value if there are records, + # or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true + # if there are records (as the max pk value is already in use), otherwise set it to false. + for f in model._meta.local_fields: + if isinstance(f, models.AutoField): + output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ + (style.SQL_KEYWORD('SELECT'), + style.SQL_FIELD(qn('%s_%s_seq' % (model._meta.db_table, f.column))), + style.SQL_FIELD(qn(f.column)), + style.SQL_FIELD(qn(f.column)), + style.SQL_KEYWORD('IS NOT'), + style.SQL_KEYWORD('FROM'), + style.SQL_TABLE(qn(model._meta.db_table)))) + break # Only one AutoField is allowed per model, so don't bother continuing. + for f in model._meta.many_to_many: + output.append("%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \ + (style.SQL_KEYWORD('SELECT'), + style.SQL_FIELD(qn('%s_id_seq' % f.m2m_db_table())), + style.SQL_FIELD(qn('id')), + style.SQL_FIELD(qn('id')), + style.SQL_KEYWORD('IS NOT'), + style.SQL_KEYWORD('FROM'), + style.SQL_TABLE(qn(f.m2m_db_table())))) + return output + + def savepoint_create_sql(self, sid): + return "SAVEPOINT %s" % sid + + def savepoint_commit_sql(self, sid): + return "RELEASE SAVEPOINT %s" % sid + + def savepoint_rollback_sql(self, sid): + return "ROLLBACK TO SAVEPOINT %s" % sid + diff --git a/webapp/django/db/backends/postgresql/version.py b/webapp/django/db/backends/postgresql/version.py new file mode 100644 index 0000000000..e14d791b07 --- /dev/null +++ b/webapp/django/db/backends/postgresql/version.py @@ -0,0 +1,18 @@ +""" +Extracts the version of the PostgreSQL server. +""" + +import re + +VERSION_RE = re.compile(r'PostgreSQL (\d+)\.(\d+)\.') + +def get_version(cursor): + """ + Returns a tuple representing the major and minor version number of the + server. For example, (7, 4) or (8, 3). + """ + cursor.execute("SELECT version()") + version = cursor.fetchone()[0] + major, minor = VERSION_RE.search(version).groups() + return int(major), int(minor) + |