# Copyright (C) 2007 AG Projects. See LICENSE for details. # """GNUTLS data validators""" __all__ = ['function_args', 'method_args', 'none', 'ignore', 'list_of', 'one_of', 'ProtocolListValidator', 'KeyExchangeListValidator', 'CipherListValidator', 'MACListValidator', 'CompressionListValidator'] from gnutls.constants import * # # Priority list validators. They take a tuple or list on input and output a # tuple with the same elements if they check valid, else raise an exception. # class ProtocolListValidator(tuple): _protocols = set((PROTO_TLS1_1, PROTO_TLS1_0, PROTO_SSL3)) def __new__(cls, arg): if not isinstance(arg, (tuple, list)): raise TypeError("Argument must be a tuple or list") if not arg: raise ValueError("Protocol list cannot be empty") if not cls._protocols.issuperset(set(arg)): raise ValueError("Got invalid protocol") return tuple.__new__(cls, arg) class KeyExchangeListValidator(tuple): _algorithms = set((KX_RSA, KX_DHE_DSS, KX_DHE_RSA, KX_RSA_EXPORT, KX_ANON_DH)) def __new__(cls, arg): if not isinstance(arg, (tuple, list)): raise TypeError("Argument must be a tuple or list") if not arg: raise ValueError("Key exchange algorithm list cannot be empty") if not cls._algorithms.issuperset(set(arg)): raise ValueError("Got invalid key exchange algorithm") return tuple.__new__(cls, arg) class CipherListValidator(tuple): _ciphers = set((CIPHER_AES_128_CBC, CIPHER_3DES_CBC, CIPHER_ARCFOUR_128, CIPHER_AES_256_CBC, CIPHER_DES_CBC)) def __new__(cls, arg): if not isinstance(arg, (tuple, list)): raise TypeError("Argument must be a tuple or list") if not arg: raise ValueError("Cipher list cannot be empty") if not cls._ciphers.issuperset(set(arg)): raise ValueError("Got invalid cipher") return tuple.__new__(cls, arg) class MACListValidator(tuple): _algorithms = set((MAC_SHA1, MAC_MD5, MAC_RMD160)) def __new__(cls, arg): if not isinstance(arg, (tuple, list)): raise TypeError("Argument must be a tuple or list") if not arg: raise ValueError("MAC algorithm list cannot be empty") if not cls._algorithms.issuperset(set(arg)): raise ValueError("Got invalid MAC algorithm") return tuple.__new__(cls, arg) class CompressionListValidator(tuple): _compressions = set((COMP_DEFLATE, COMP_LZO, COMP_NULL)) def __new__(cls, arg): if not isinstance(arg, (tuple, list)): raise TypeError("Argument must be a tuple or list") if not arg: raise ValueError("Compression list cannot be empty") if not cls._compressions.issuperset(set(arg)): raise ValueError("Got invalid compression") return tuple.__new__(cls, arg) # # Argument validating # # Helper functions (internal use) # def isclass(obj): return hasattr(obj, '__bases__') or isinstance(obj, type) # Internal validator classes # class Validator(object): _registered = [] def __init__(self, typ): self.type = typ def check(self, value): return False @staticmethod def can_validate(typ): return False @classmethod def register(cls, validator): cls._registered.append(validator) @classmethod def get(cls, typ): for validator in cls._registered: if validator.can_validate(typ): return validator(typ) else: return None @staticmethod def join_names(names): if type(names) in (tuple, list): if len(names) <= 2: return ' or '.join(names) else: return ' or '.join((', '.join(names[:-1]), names[-1])) else: return names def _type_names(self): if isinstance(self.type, tuple): return self.join_names([t.__name__.replace('NoneType', 'None') for t in self.type]) else: return self.type.__name__.replace('NoneType', 'None') @property def name(self): name = self._type_names() if name.startswith('None'): prefix = '' elif name[0] in ('a', 'e', 'i', 'o', 'u'): prefix = 'an ' else: prefix = 'a ' return prefix + name class IgnoringValidator(Validator): def __init__(self, typ): self.type = none def check(self, value): return True @staticmethod def can_validate(obj): return obj is ignore class TypeValidator(Validator): def check(self, value): return isinstance(value, self.type) @staticmethod def can_validate(obj): return isclass(obj) class MultiTypeValidator(TypeValidator): @staticmethod def can_validate(obj): return isinstance(obj, tuple) and not filter(lambda x: not isclass(x), obj) class OneOfValidator(Validator): def __init__(self, typ): self.type = typ.type def check(self, value): return value in self.type @staticmethod def can_validate(obj): return isinstance(obj, one_of) @property def name(self): return 'one of %s' % self.join_names(["`%r'" % e for e in self.type]) class ListOfValidator(Validator): def __init__(self, typ): self.type = typ.type def check(self, value): return isinstance(value, (tuple, list)) and not filter(lambda x: not isinstance(x, self.type), value) @staticmethod def can_validate(obj): return isinstance(obj, list_of) @property def name(self): return 'a list of %s' % self._type_names() class ComplexValidator(Validator): def __init__(self, typ): self.type = [Validator.get(x) for x in typ] def check(self, value): return bool(sum(t.check(value) for t in self.type)) @staticmethod def can_validate(obj): return isinstance(obj, tuple) and not filter(lambda x: Validator.get(x) is None, obj) @property def name(self): return self.join_names([x.name for x in self.type]) Validator.register(IgnoringValidator) Validator.register(TypeValidator) Validator.register(MultiTypeValidator) Validator.register(OneOfValidator) Validator.register(ListOfValidator) Validator.register(ComplexValidator) # Extra types to be used with argument validating decorators # none = type(None) class one_of(object): def __init__(self, *args): if len(args) < 2: raise ValueError("one_of must have at least 2 arguments") self.type = args class list_of(object): def __init__(self, *args): if filter(lambda x: not isclass(x), args): raise TypeError("list_of arguments must be types") if len(args) == 1: self.type = args[0] else: self.type = args ignore = type('ignore', (), {})() # Helpers for writing well behaved decorators # def decorator(func): """A syntactic marker with no other effect than improving readability.""" return func def preserve_signature(func): """Preserve the original function signature and attributes in decorator wrappers.""" from inspect import getargspec, formatargspec from gnutls.constants import GNUTLSConstant constants = [c for c in (getargspec(func)[3] or []) if isinstance(c, GNUTLSConstant)] signature = formatargspec(*getargspec(func))[1:-1] parameters = formatargspec(*getargspec(func), **{'formatvalue': lambda value: ""})[1:-1] def fix_signature(wrapper): if constants: ## import the required GNUTLSConstants used as function default arguments code = "from gnutls.constants import %s\n" % ', '.join(c.name for c in constants) exec code in locals(), locals() code = "def %s(%s): return wrapper(%s)\nnew_wrapper = %s\n" % (func.__name__, signature, parameters, func.__name__) exec code in locals(), locals() new_wrapper.__name__ = func.__name__ new_wrapper.__doc__ = func.__doc__ new_wrapper.__module__ = func.__module__ new_wrapper.__dict__.update(func.__dict__) return new_wrapper return fix_signature # Argument validating decorators # def _callable_args(*args, **kwargs): """Internal function used by argument checking decorators""" start = kwargs.get('_start', 0) validators = [] for i, arg in enumerate(args): validator = Validator.get(arg) if validator is None: raise TypeError("unsupported type `%r' at position %d for argument checking decorator" % (arg, i+1)) validators.append(validator) def check_args_decorator(func): @preserve_signature(func) def check_args(*func_args): pos = start for validator in validators: if not validator.check(func_args[pos]): raise TypeError("argument %d must be %s" % (pos+1-start, validator.name)) pos += 1 return func(*func_args) return check_args return check_args_decorator @decorator def method_args(*args): """Check class or instance method arguments""" return _callable_args(*args, **{'_start': 1}) @decorator def function_args(*args): """Check functions or staticmethod arguments""" return _callable_args(*args)