diff options
| author | kali <kali@leap.se> | 2012-07-22 21:10:15 -0700 | 
|---|---|---|
| committer | kali <kali@leap.se> | 2012-07-22 21:10:15 -0700 | 
| commit | c46d8da153ac658c8bd145376e22b1218db1090a (patch) | |
| tree | 0943a4a866d9f3b1bc590c1c23f810ca13635f9e /tests/support.py | |
initial import
Diffstat (limited to 'tests/support.py')
| -rw-r--r-- | tests/support.py | 111 | 
1 files changed, 111 insertions, 0 deletions
diff --git a/tests/support.py b/tests/support.py new file mode 100644 index 00000000..8ac49669 --- /dev/null +++ b/tests/support.py @@ -0,0 +1,111 @@ +# code borrowed from python stdlib tests +# I think we're not using it at the end...  +# XXX Review and Remove + +import contextlib +import socket +import sys +import unittest + + +HOST = "localhost" + + +class TestFailed(Exception): +    """Test failed.""" + + +def bind_port(sock, host=HOST): +    """Bind the socket to a free port and return the port number.  Relies on +    ephemeral ports in order to ensure we are using an unbound port.  This is +    important as many tests may be running simultaneously, especially in a +    buildbot environment.  This method raises an exception if the sock.family +    is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR +    or SO_REUSEPORT set on it.  Tests should *never* set these socket options +    for TCP/IP sockets.  The only case for setting these options is testing +    multicasting via multiple UDP sockets. + +    Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e. +    on Windows), it will be set on the socket.  This will prevent anyone else +    from bind()'ing to our host/port for the duration of the test. +    """ + +    if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: +        if hasattr(socket, 'SO_REUSEADDR'): +            if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: +                raise TestFailed("tests should never set the SO_REUSEADDR "   \ +                                 "socket option on TCP/IP sockets!") +        if hasattr(socket, 'SO_REUSEPORT'): +            if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: +                raise TestFailed("tests should never set the SO_REUSEPORT "   \ +                                 "socket option on TCP/IP sockets!") +        if hasattr(socket, 'SO_EXCLUSIVEADDRUSE'): +            sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + +    sock.bind((host, 0)) +    port = sock.getsockname()[1] +    return port + + +def _run_suite(suite): +    """Run tests from a unittest.TestSuite-derived class.""" +    runner = unittest.TextTestRunner(sys.stdout, verbosity=2, +                                         failfast=False) +    result = runner.run(suite) +    if not result.wasSuccessful(): +        if len(result.errors) == 1 and not result.failures: +            err = result.errors[0][1] +        elif len(result.failures) == 1 and not result.errors: +            err = result.failures[0][1] +        else: +            err = "multiple errors occurred" +        raise TestFailed(err) + + +def run_unittest(*classes): +    """Run tests from unittest.TestCase-derived classes.""" +    valid_types = (unittest.TestSuite, unittest.TestCase) +    suite = unittest.TestSuite() +    for cls in classes: +        if isinstance(cls, str): +            if cls in sys.modules: +                suite.addTest(unittest.findTestCases(sys.modules[cls])) +            else: +                raise ValueError("str arguments must be keys in sys.modules") +        elif isinstance(cls, valid_types): +            suite.addTest(cls) +        else: +            suite.addTest(unittest.makeSuite(cls)) + +    _run_suite(suite) + + +@contextlib.contextmanager +def captured_output(stream_name): +    """Return a context manager used by captured_stdout/stdin/stderr +    that temporarily replaces the sys stream *stream_name* with a StringIO.""" +    import io +    orig_stdout = getattr(sys, stream_name) +    setattr(sys, stream_name, io.StringIO()) +    try: +        yield getattr(sys, stream_name) +    finally: +        setattr(sys, stream_name, orig_stdout) + + +def captured_stdout(): +    """Capture the output of sys.stdout: + +       with captured_stdout() as s: +           print("hello") +       self.assertEqual(s.getvalue(), "hello") +    """ +    return captured_output("stdout") + + +def captured_stderr(): +    return captured_output("stderr") + + +def captured_stdin(): +    return captured_output("stdin")  | 
