diff options
Diffstat (limited to 'webapp/django/test/testcases.py')
-rw-r--r-- | webapp/django/test/testcases.py | 356 |
1 files changed, 356 insertions, 0 deletions
diff --git a/webapp/django/test/testcases.py b/webapp/django/test/testcases.py new file mode 100644 index 0000000000..dcab078553 --- /dev/null +++ b/webapp/django/test/testcases.py @@ -0,0 +1,356 @@ +import re +import unittest +from urlparse import urlsplit, urlunsplit +from xml.dom.minidom import parseString, Node + +from django.conf import settings +from django.core import mail +from django.core.management import call_command +from django.core.urlresolvers import clear_url_caches +from django.db import transaction +from django.http import QueryDict +from django.test import _doctest as doctest +from django.test.client import Client +from django.utils import simplejson + +normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s) + +def to_list(value): + """ + Puts value into a list if it's not already one. + Returns an empty list if value is None. + """ + if value is None: + value = [] + elif not isinstance(value, list): + value = [value] + return value + + +class OutputChecker(doctest.OutputChecker): + def check_output(self, want, got, optionflags): + "The entry method for doctest output checking. Defers to a sequence of child checkers" + checks = (self.check_output_default, + self.check_output_long, + self.check_output_xml, + self.check_output_json) + for check in checks: + if check(want, got, optionflags): + return True + return False + + def check_output_default(self, want, got, optionflags): + "The default comparator provided by doctest - not perfect, but good for most purposes" + return doctest.OutputChecker.check_output(self, want, got, optionflags) + + def check_output_long(self, want, got, optionflags): + """Doctest does an exact string comparison of output, which means long + integers aren't equal to normal integers ("22L" vs. "22"). The + following code normalizes long integers so that they equal normal + integers. + """ + return normalize_long_ints(want) == normalize_long_ints(got) + + def check_output_xml(self, want, got, optionsflags): + """Tries to do a 'xml-comparision' of want and got. Plain string + comparision doesn't always work because, for example, attribute + ordering should not be important. + + Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py + """ + _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+') + def norm_whitespace(v): + return _norm_whitespace_re.sub(' ', v) + + def child_text(element): + return ''.join([c.data for c in element.childNodes + if c.nodeType == Node.TEXT_NODE]) + + def children(element): + return [c for c in element.childNodes + if c.nodeType == Node.ELEMENT_NODE] + + def norm_child_text(element): + return norm_whitespace(child_text(element)) + + def attrs_dict(element): + return dict(element.attributes.items()) + + def check_element(want_element, got_element): + if want_element.tagName != got_element.tagName: + return False + if norm_child_text(want_element) != norm_child_text(got_element): + return False + if attrs_dict(want_element) != attrs_dict(got_element): + return False + want_children = children(want_element) + got_children = children(got_element) + if len(want_children) != len(got_children): + return False + for want, got in zip(want_children, got_children): + if not check_element(want, got): + return False + return True + + want, got = self._strip_quotes(want, got) + want = want.replace('\\n','\n') + got = got.replace('\\n','\n') + + # If the string is not a complete xml document, we may need to add a + # root element. This allow us to compare fragments, like "<foo/><bar/>" + if not want.startswith('<?xml'): + wrapper = '<root>%s</root>' + want = wrapper % want + got = wrapper % got + + # Parse the want and got strings, and compare the parsings. + try: + want_root = parseString(want).firstChild + got_root = parseString(got).firstChild + except: + return False + return check_element(want_root, got_root) + + def check_output_json(self, want, got, optionsflags): + "Tries to compare want and got as if they were JSON-encoded data" + want, got = self._strip_quotes(want, got) + try: + want_json = simplejson.loads(want) + got_json = simplejson.loads(got) + except: + return False + return want_json == got_json + + def _strip_quotes(self, want, got): + """ + Strip quotes of doctests output values: + + >>> o = OutputChecker() + >>> o._strip_quotes("'foo'") + "foo" + >>> o._strip_quotes('"foo"') + "foo" + >>> o._strip_quotes("u'foo'") + "foo" + >>> o._strip_quotes('u"foo"') + "foo" + """ + def is_quoted_string(s): + s = s.strip() + return (len(s) >= 2 + and s[0] == s[-1] + and s[0] in ('"', "'")) + + def is_quoted_unicode(s): + s = s.strip() + return (len(s) >= 3 + and s[0] == 'u' + and s[1] == s[-1] + and s[1] in ('"', "'")) + + if is_quoted_string(want) and is_quoted_string(got): + want = want.strip()[1:-1] + got = got.strip()[1:-1] + elif is_quoted_unicode(want) and is_quoted_unicode(got): + want = want.strip()[2:-1] + got = got.strip()[2:-1] + return want, got + + +class DocTestRunner(doctest.DocTestRunner): + def __init__(self, *args, **kwargs): + doctest.DocTestRunner.__init__(self, *args, **kwargs) + self.optionflags = doctest.ELLIPSIS + + def report_unexpected_exception(self, out, test, example, exc_info): + doctest.DocTestRunner.report_unexpected_exception(self, out, test, + example, exc_info) + # Rollback, in case of database errors. Otherwise they'd have + # side effects on other tests. + transaction.rollback_unless_managed() + +class TestCase(unittest.TestCase): + def _pre_setup(self): + """Performs any pre-test setup. This includes: + + * Flushing the database. + * If the Test Case class has a 'fixtures' member, installing the + named fixtures. + * If the Test Case class has a 'urls' member, replace the + ROOT_URLCONF with it. + * Clearing the mail test outbox. + """ + call_command('flush', verbosity=0, interactive=False) + if hasattr(self, 'fixtures'): + # We have to use this slightly awkward syntax due to the fact + # that we're using *args and **kwargs together. + call_command('loaddata', *self.fixtures, **{'verbosity': 0}) + if hasattr(self, 'urls'): + self._old_root_urlconf = settings.ROOT_URLCONF + settings.ROOT_URLCONF = self.urls + clear_url_caches() + mail.outbox = [] + + def __call__(self, result=None): + """ + Wrapper around default __call__ method to perform common Django test + set up. This means that user-defined Test Cases aren't required to + include a call to super().setUp(). + """ + self.client = Client() + try: + self._pre_setup() + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + import sys + result.addError(self, sys.exc_info()) + return + super(TestCase, self).__call__(result) + try: + self._post_teardown() + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + import sys + result.addError(self, sys.exc_info()) + return + + def _post_teardown(self): + """ Performs any post-test things. This includes: + + * Putting back the original ROOT_URLCONF if it was changed. + """ + if hasattr(self, '_old_root_urlconf'): + settings.ROOT_URLCONF = self._old_root_urlconf + clear_url_caches() + + def assertRedirects(self, response, expected_url, status_code=302, + target_status_code=200, host=None): + """Asserts that a response redirected to a specific URL, and that the + redirect URL can be loaded. + + Note that assertRedirects won't work for external links since it uses + TestClient to do a request. + """ + self.assertEqual(response.status_code, status_code, + ("Response didn't redirect as expected: Response code was %d" + " (expected %d)" % (response.status_code, status_code))) + url = response['Location'] + scheme, netloc, path, query, fragment = urlsplit(url) + e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url) + if not (e_scheme or e_netloc): + expected_url = urlunsplit(('http', host or 'testserver', e_path, + e_query, e_fragment)) + self.assertEqual(url, expected_url, + "Response redirected to '%s', expected '%s'" % (url, expected_url)) + + # Get the redirection page, using the same client that was used + # to obtain the original response. + redirect_response = response.client.get(path, QueryDict(query)) + self.assertEqual(redirect_response.status_code, target_status_code, + ("Couldn't retrieve redirection page '%s': response code was %d" + " (expected %d)") % + (path, redirect_response.status_code, target_status_code)) + + def assertContains(self, response, text, count=None, status_code=200): + """ + Asserts that a response indicates that a page was retrieved + successfully, (i.e., the HTTP status code was as expected), and that + ``text`` occurs ``count`` times in the content of the response. + If ``count`` is None, the count doesn't matter - the assertion is true + if the text occurs at least once in the response. + """ + self.assertEqual(response.status_code, status_code, + "Couldn't retrieve page: Response code was %d (expected %d)'" % + (response.status_code, status_code)) + real_count = response.content.count(text) + if count is not None: + self.assertEqual(real_count, count, + "Found %d instances of '%s' in response (expected %d)" % + (real_count, text, count)) + else: + self.failUnless(real_count != 0, + "Couldn't find '%s' in response" % text) + + def assertNotContains(self, response, text, status_code=200): + """ + Asserts that a response indicates that a page was retrieved + successfully, (i.e., the HTTP status code was as expected), and that + ``text`` doesn't occurs in the content of the response. + """ + self.assertEqual(response.status_code, status_code, + "Couldn't retrieve page: Response code was %d (expected %d)'" % + (response.status_code, status_code)) + self.assertEqual(response.content.count(text), 0, + "Response should not contain '%s'" % text) + + def assertFormError(self, response, form, field, errors): + """ + Asserts that a form used to render the response has a specific field + error. + """ + # Put context(s) into a list to simplify processing. + contexts = to_list(response.context) + if not contexts: + self.fail('Response did not use any contexts to render the' + ' response') + + # Put error(s) into a list to simplify processing. + errors = to_list(errors) + + # Search all contexts for the error. + found_form = False + for i,context in enumerate(contexts): + if form not in context: + continue + found_form = True + for err in errors: + if field: + if field in context[form].errors: + field_errors = context[form].errors[field] + self.failUnless(err in field_errors, + "The field '%s' on form '%s' in" + " context %d does not contain the" + " error '%s' (actual errors: %s)" % + (field, form, i, err, + repr(field_errors))) + elif field in context[form].fields: + self.fail("The field '%s' on form '%s' in context %d" + " contains no errors" % (field, form, i)) + else: + self.fail("The form '%s' in context %d does not" + " contain the field '%s'" % + (form, i, field)) + else: + non_field_errors = context[form].non_field_errors() + self.failUnless(err in non_field_errors, + "The form '%s' in context %d does not contain the" + " non-field error '%s' (actual errors: %s)" % + (form, i, err, non_field_errors)) + if not found_form: + self.fail("The form '%s' was not used to render the response" % + form) + + def assertTemplateUsed(self, response, template_name): + """ + Asserts that the template with the provided name was used in rendering + the response. + """ + template_names = [t.name for t in to_list(response.template)] + if not template_names: + self.fail('No templates used to render the response') + self.failUnless(template_name in template_names, + (u"Template '%s' was not a template used to render the response." + u" Actual template(s) used: %s") % (template_name, + u', '.join(template_names))) + + def assertTemplateNotUsed(self, response, template_name): + """ + Asserts that the template with the provided name was NOT used in + rendering the response. + """ + template_names = [t.name for t in to_list(response.template)] + self.failIf(template_name in template_names, + (u"Template '%s' was used unexpectedly in rendering the" + u" response") % template_name) |