summaryrefslogtreecommitdiff
path: root/src/leap/common/events/auth.py
blob: db217caa37e2b8d850dea25efdb20ade1e692055 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# -*- coding: utf-8 -*-
# auth.py
# Copyright (C) 2016 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""
ZAP authentication, twisted style.
"""
from zmq import PAIR
from zmq.auth.base import Authenticator, VERSION
from txzmq.connection import ZmqConnection
from zmq.utils.strtypes import b, u

from twisted.python import log

from txzmq.connection import ZmqEndpoint, ZmqEndpointType


class TxAuthenticator(ZmqConnection):

    """
    Twisted-side authenticator bound to the in-process ZAP address.

    This does not implement the whole ZAP protocol, but the bare minimum that
    we need: configuration commands (``ALLOW`` and ``CURVE``) arrive over a
    PAIR socket and are forwarded to a wrapped pyzmq ``Authenticator``, whose
    ZAP replies are sent back out through this connection.
    """

    socketType = PAIR
    address = 'inproc://zeromq.zap.01'
    encoding = 'utf-8'

    def __init__(self, factory, *args, **kw):
        """
        :param factory: the txzmq factory; its ``context`` is also handed to
            the wrapped pyzmq ``Authenticator``.
        """
        super(TxAuthenticator, self).__init__(factory, *args, **kw)
        self.authenticator = Authenticator(factory.context)
        # Replace the authenticator's reply method so that ZAP replies are
        # emitted through this twisted connection rather than its default
        # implementation.
        self.authenticator._send_zap_reply = self._send_zap_reply

    def start(self):
        """
        Bind this connection's PAIR socket to ``self.address``.
        """
        endpoint = ZmqEndpoint(ZmqEndpointType.bind, self.address)
        self.addEndpoints([endpoint])

    def messageReceived(self, msg):
        """
        Dispatch a configuration command received on the PAIR socket.

        :param msg: list of frames. The first frame is the command
            (``b'ALLOW'`` or ``b'CURVE'``) and the remaining frames are its
            arguments, encoded with ``self.encoding``. Other commands are
            silently ignored.
        """
        command = msg[0]

        if command == b'ALLOW':
            addresses = [u(m, self.encoding) for m in msg[1:]]
            try:
                self.authenticator.allow(*addresses)
            except Exception:
                # log.err() takes (failure, why), not a format string plus
                # args. Passing None makes Twisted wrap the exception being
                # handled right now, so the real error and traceback are
                # logged alongside the explanation.
                log.err(None, "Failed to allow %s" % (addresses,))

        elif command == b'CURVE':
            # Frames: [b'CURVE', domain, location]
            domain = u(msg[1], self.encoding)
            location = u(msg[2], self.encoding)
            self.authenticator.configure_curve(domain, location)

    def _send_zap_reply(self, request_id, status_code, status_text,
                        user_id='user'):
        """
        Send a ZAP reply to finish the authentication.

        Installed on the wrapped ``Authenticator`` in ``__init__`` in place of
        its own reply method.

        :param request_id: request id frame echoed back in the reply.
        :param status_code: ZAP status code as bytes (e.g. ``b'200'``).
        :param status_text: ZAP status text frame.
        :param user_id: user id reported on success. It is replaced by
            ``b''`` for any status other than ``b'200'``. Text values are
            encoded with ``self.encoding``.
        """
        user_id = user_id if status_code == b'200' else b''
        # ``type(u'')`` is the text type on both Python 2 (unicode) and
        # Python 3 (str). The bare name ``unicode`` raises NameError on 3.
        if isinstance(user_id, type(u'')):
            user_id = user_id.encode(self.encoding, 'replace')
        metadata = b''  # not currently used
        reply = [VERSION, request_id, status_code, status_text,
                 user_id, metadata]
        self.send(reply)

    def shutdown(self):
        """
        Shut the connection down, but only while a factory is attached.
        """
        # NOTE(review): presumably guards against shutting down an
        # already-detached connection — TODO confirm.
        if self.factory:
            super(TxAuthenticator, self).shutdown()


class TxAuthenticationRequest(ZmqConnection):

    """
    Client side of the in-process authenticator channel.

    Connects a PAIR socket to the authenticator's inproc address and sends it
    configuration commands as multipart messages whose first frame names the
    command.
    """

    socketType = PAIR
    address = 'inproc://zeromq.zap.01'
    encoding = 'utf-8'

    def start(self):
        """
        Connect this connection's PAIR socket to ``self.address``.
        """
        self.addEndpoints([
            ZmqEndpoint(ZmqEndpointType.connect, self.address),
        ])

    def allow(self, *addresses):
        """
        Ask the authenticator to allow the given addresses.

        Sends a ``b'ALLOW'`` frame followed by one frame per address, each
        encoded with ``self.encoding``.
        """
        frames = [b'ALLOW']
        frames.extend(b(address, self.encoding) for address in addresses)
        self.send(frames)

    def configure_curve(self, domain='*', location=''):
        """
        Ask the authenticator to configure CURVE authentication.

        Sends ``[b'CURVE', domain, location]`` with both arguments encoded
        with ``self.encoding``.

        :param domain: domain to configure, ``'*'`` by default.
        :param location: location value forwarded to the authenticator; its
            interpretation is up to the receiving side.
        """
        frames = [b'CURVE']
        for value in (domain, location):
            frames.append(b(value, self.encoding))
        self.send(frames)