summaryrefslogtreecommitdiffstats
path: root/webapp/django/test/testcases.py
diff options
context:
space:
mode:
Diffstat (limited to 'webapp/django/test/testcases.py')
-rw-r--r--webapp/django/test/testcases.py356
1 files changed, 356 insertions, 0 deletions
diff --git a/webapp/django/test/testcases.py b/webapp/django/test/testcases.py
new file mode 100644
index 0000000000..dcab078553
--- /dev/null
+++ b/webapp/django/test/testcases.py
@@ -0,0 +1,356 @@
+import re
+import unittest
+from urlparse import urlsplit, urlunsplit
+from xml.dom.minidom import parseString, Node
+
+from django.conf import settings
+from django.core import mail
+from django.core.management import call_command
+from django.core.urlresolvers import clear_url_caches
+from django.db import transaction
+from django.http import QueryDict
+from django.test import _doctest as doctest
+from django.test.client import Client
+from django.utils import simplejson
+
+normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
+
+def to_list(value):
+ """
+ Puts value into a list if it's not already one.
+ Returns an empty list if value is None.
+ """
+ if value is None:
+ value = []
+ elif not isinstance(value, list):
+ value = [value]
+ return value
+
+
+class OutputChecker(doctest.OutputChecker):
+ def check_output(self, want, got, optionflags):
+ "The entry method for doctest output checking. Defers to a sequence of child checkers"
+ checks = (self.check_output_default,
+ self.check_output_long,
+ self.check_output_xml,
+ self.check_output_json)
+ for check in checks:
+ if check(want, got, optionflags):
+ return True
+ return False
+
+ def check_output_default(self, want, got, optionflags):
+ "The default comparator provided by doctest - not perfect, but good for most purposes"
+ return doctest.OutputChecker.check_output(self, want, got, optionflags)
+
+ def check_output_long(self, want, got, optionflags):
+ """Doctest does an exact string comparison of output, which means long
+ integers aren't equal to normal integers ("22L" vs. "22"). The
+ following code normalizes long integers so that they equal normal
+ integers.
+ """
+ return normalize_long_ints(want) == normalize_long_ints(got)
+
+ def check_output_xml(self, want, got, optionsflags):
+ """Tries to do a 'xml-comparision' of want and got. Plain string
+ comparision doesn't always work because, for example, attribute
+ ordering should not be important.
+
+ Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
+ """
+ _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
+ def norm_whitespace(v):
+ return _norm_whitespace_re.sub(' ', v)
+
+ def child_text(element):
+ return ''.join([c.data for c in element.childNodes
+ if c.nodeType == Node.TEXT_NODE])
+
+ def children(element):
+ return [c for c in element.childNodes
+ if c.nodeType == Node.ELEMENT_NODE]
+
+ def norm_child_text(element):
+ return norm_whitespace(child_text(element))
+
+ def attrs_dict(element):
+ return dict(element.attributes.items())
+
+ def check_element(want_element, got_element):
+ if want_element.tagName != got_element.tagName:
+ return False
+ if norm_child_text(want_element) != norm_child_text(got_element):
+ return False
+ if attrs_dict(want_element) != attrs_dict(got_element):
+ return False
+ want_children = children(want_element)
+ got_children = children(got_element)
+ if len(want_children) != len(got_children):
+ return False
+ for want, got in zip(want_children, got_children):
+ if not check_element(want, got):
+ return False
+ return True
+
+ want, got = self._strip_quotes(want, got)
+ want = want.replace('\\n','\n')
+ got = got.replace('\\n','\n')
+
+ # If the string is not a complete xml document, we may need to add a
+ # root element. This allow us to compare fragments, like "<foo/><bar/>"
+ if not want.startswith('<?xml'):
+ wrapper = '<root>%s</root>'
+ want = wrapper % want
+ got = wrapper % got
+
+ # Parse the want and got strings, and compare the parsings.
+ try:
+ want_root = parseString(want).firstChild
+ got_root = parseString(got).firstChild
+ except:
+ return False
+ return check_element(want_root, got_root)
+
+ def check_output_json(self, want, got, optionsflags):
+ "Tries to compare want and got as if they were JSON-encoded data"
+ want, got = self._strip_quotes(want, got)
+ try:
+ want_json = simplejson.loads(want)
+ got_json = simplejson.loads(got)
+ except:
+ return False
+ return want_json == got_json
+
+ def _strip_quotes(self, want, got):
+ """
+ Strip quotes of doctests output values:
+
+ >>> o = OutputChecker()
+ >>> o._strip_quotes("'foo'")
+ "foo"
+ >>> o._strip_quotes('"foo"')
+ "foo"
+ >>> o._strip_quotes("u'foo'")
+ "foo"
+ >>> o._strip_quotes('u"foo"')
+ "foo"
+ """
+ def is_quoted_string(s):
+ s = s.strip()
+ return (len(s) >= 2
+ and s[0] == s[-1]
+ and s[0] in ('"', "'"))
+
+ def is_quoted_unicode(s):
+ s = s.strip()
+ return (len(s) >= 3
+ and s[0] == 'u'
+ and s[1] == s[-1]
+ and s[1] in ('"', "'"))
+
+ if is_quoted_string(want) and is_quoted_string(got):
+ want = want.strip()[1:-1]
+ got = got.strip()[1:-1]
+ elif is_quoted_unicode(want) and is_quoted_unicode(got):
+ want = want.strip()[2:-1]
+ got = got.strip()[2:-1]
+ return want, got
+
+
+class DocTestRunner(doctest.DocTestRunner):
+ def __init__(self, *args, **kwargs):
+ doctest.DocTestRunner.__init__(self, *args, **kwargs)
+ self.optionflags = doctest.ELLIPSIS
+
+ def report_unexpected_exception(self, out, test, example, exc_info):
+ doctest.DocTestRunner.report_unexpected_exception(self, out, test,
+ example, exc_info)
+ # Rollback, in case of database errors. Otherwise they'd have
+ # side effects on other tests.
+ transaction.rollback_unless_managed()
+
+class TestCase(unittest.TestCase):
+ def _pre_setup(self):
+ """Performs any pre-test setup. This includes:
+
+ * Flushing the database.
+ * If the Test Case class has a 'fixtures' member, installing the
+ named fixtures.
+ * If the Test Case class has a 'urls' member, replace the
+ ROOT_URLCONF with it.
+ * Clearing the mail test outbox.
+ """
+ call_command('flush', verbosity=0, interactive=False)
+ if hasattr(self, 'fixtures'):
+ # We have to use this slightly awkward syntax due to the fact
+ # that we're using *args and **kwargs together.
+ call_command('loaddata', *self.fixtures, **{'verbosity': 0})
+ if hasattr(self, 'urls'):
+ self._old_root_urlconf = settings.ROOT_URLCONF
+ settings.ROOT_URLCONF = self.urls
+ clear_url_caches()
+ mail.outbox = []
+
+ def __call__(self, result=None):
+ """
+ Wrapper around default __call__ method to perform common Django test
+ set up. This means that user-defined Test Cases aren't required to
+ include a call to super().setUp().
+ """
+ self.client = Client()
+ try:
+ self._pre_setup()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except Exception:
+ import sys
+ result.addError(self, sys.exc_info())
+ return
+ super(TestCase, self).__call__(result)
+ try:
+ self._post_teardown()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except Exception:
+ import sys
+ result.addError(self, sys.exc_info())
+ return
+
+ def _post_teardown(self):
+ """ Performs any post-test things. This includes:
+
+ * Putting back the original ROOT_URLCONF if it was changed.
+ """
+ if hasattr(self, '_old_root_urlconf'):
+ settings.ROOT_URLCONF = self._old_root_urlconf
+ clear_url_caches()
+
+ def assertRedirects(self, response, expected_url, status_code=302,
+ target_status_code=200, host=None):
+ """Asserts that a response redirected to a specific URL, and that the
+ redirect URL can be loaded.
+
+ Note that assertRedirects won't work for external links since it uses
+ TestClient to do a request.
+ """
+ self.assertEqual(response.status_code, status_code,
+ ("Response didn't redirect as expected: Response code was %d"
+ " (expected %d)" % (response.status_code, status_code)))
+ url = response['Location']
+ scheme, netloc, path, query, fragment = urlsplit(url)
+ e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
+ if not (e_scheme or e_netloc):
+ expected_url = urlunsplit(('http', host or 'testserver', e_path,
+ e_query, e_fragment))
+ self.assertEqual(url, expected_url,
+ "Response redirected to '%s', expected '%s'" % (url, expected_url))
+
+ # Get the redirection page, using the same client that was used
+ # to obtain the original response.
+ redirect_response = response.client.get(path, QueryDict(query))
+ self.assertEqual(redirect_response.status_code, target_status_code,
+ ("Couldn't retrieve redirection page '%s': response code was %d"
+ " (expected %d)") %
+ (path, redirect_response.status_code, target_status_code))
+
+ def assertContains(self, response, text, count=None, status_code=200):
+ """
+ Asserts that a response indicates that a page was retrieved
+ successfully, (i.e., the HTTP status code was as expected), and that
+ ``text`` occurs ``count`` times in the content of the response.
+ If ``count`` is None, the count doesn't matter - the assertion is true
+ if the text occurs at least once in the response.
+ """
+ self.assertEqual(response.status_code, status_code,
+ "Couldn't retrieve page: Response code was %d (expected %d)'" %
+ (response.status_code, status_code))
+ real_count = response.content.count(text)
+ if count is not None:
+ self.assertEqual(real_count, count,
+ "Found %d instances of '%s' in response (expected %d)" %
+ (real_count, text, count))
+ else:
+ self.failUnless(real_count != 0,
+ "Couldn't find '%s' in response" % text)
+
+ def assertNotContains(self, response, text, status_code=200):
+ """
+ Asserts that a response indicates that a page was retrieved
+ successfully, (i.e., the HTTP status code was as expected), and that
+ ``text`` doesn't occurs in the content of the response.
+ """
+ self.assertEqual(response.status_code, status_code,
+ "Couldn't retrieve page: Response code was %d (expected %d)'" %
+ (response.status_code, status_code))
+ self.assertEqual(response.content.count(text), 0,
+ "Response should not contain '%s'" % text)
+
+ def assertFormError(self, response, form, field, errors):
+ """
+ Asserts that a form used to render the response has a specific field
+ error.
+ """
+ # Put context(s) into a list to simplify processing.
+ contexts = to_list(response.context)
+ if not contexts:
+ self.fail('Response did not use any contexts to render the'
+ ' response')
+
+ # Put error(s) into a list to simplify processing.
+ errors = to_list(errors)
+
+ # Search all contexts for the error.
+ found_form = False
+ for i,context in enumerate(contexts):
+ if form not in context:
+ continue
+ found_form = True
+ for err in errors:
+ if field:
+ if field in context[form].errors:
+ field_errors = context[form].errors[field]
+ self.failUnless(err in field_errors,
+ "The field '%s' on form '%s' in"
+ " context %d does not contain the"
+ " error '%s' (actual errors: %s)" %
+ (field, form, i, err,
+ repr(field_errors)))
+ elif field in context[form].fields:
+ self.fail("The field '%s' on form '%s' in context %d"
+ " contains no errors" % (field, form, i))
+ else:
+ self.fail("The form '%s' in context %d does not"
+ " contain the field '%s'" %
+ (form, i, field))
+ else:
+ non_field_errors = context[form].non_field_errors()
+ self.failUnless(err in non_field_errors,
+ "The form '%s' in context %d does not contain the"
+ " non-field error '%s' (actual errors: %s)" %
+ (form, i, err, non_field_errors))
+ if not found_form:
+ self.fail("The form '%s' was not used to render the response" %
+ form)
+
+ def assertTemplateUsed(self, response, template_name):
+ """
+ Asserts that the template with the provided name was used in rendering
+ the response.
+ """
+ template_names = [t.name for t in to_list(response.template)]
+ if not template_names:
+ self.fail('No templates used to render the response')
+ self.failUnless(template_name in template_names,
+ (u"Template '%s' was not a template used to render the response."
+ u" Actual template(s) used: %s") % (template_name,
+ u', '.join(template_names)))
+
+ def assertTemplateNotUsed(self, response, template_name):
+ """
+ Asserts that the template with the provided name was NOT used in
+ rendering the response.
+ """
+ template_names = [t.name for t in to_list(response.template)]
+ self.failIf(template_name in template_names,
+ (u"Template '%s' was used unexpectedly in rendering the"
+ u" response") % template_name)