diff options
-rw-r--r-- | src/leap/mail/smtp/gateway.py | 32 | ||||
-rw-r--r-- | src/leap/mail/smtp/tests/test_gateway.py | 57 |
2 files changed, 57 insertions, 32 deletions
diff --git a/src/leap/mail/smtp/gateway.py b/src/leap/mail/smtp/gateway.py index 1a187cf..9d78474 100644 --- a/src/leap/mail/smtp/gateway.py +++ b/src/leap/mail/smtp/gateway.py @@ -204,22 +204,24 @@ class SMTPDelivery(object): signal(proto.SMTP_RECIPIENT_ACCEPTED_ENCRYPTED, user.dest.addrstr) def not_found(failure): - if failure.check(KeyNotFound): - # if key was not found, check config to see if will send anyway - if self._encrypted_only: - signal(proto.SMTP_RECIPIENT_REJECTED, user.dest.addrstr) - raise smtp.SMTPBadRcpt(user.dest.addrstr) - log.msg("Warning: will send an unencrypted message (because " - "encrypted_only' is set to False).") - signal( - proto.SMTP_RECIPIENT_ACCEPTED_UNENCRYPTED, - user.dest.addrstr) - else: - return failure - - d = self._km.get_key(address, OpenPGPKey) # might raise KeyNotFound + failure.trap(KeyNotFound) + + # if key was not found, check config to see if will send anyway + if self._encrypted_only: + signal(proto.SMTP_RECIPIENT_REJECTED, user.dest.addrstr) + raise smtp.SMTPBadRcpt(user.dest.addrstr) + log.msg("Warning: will send an unencrypted message (because " + "encrypted_only' is set to False).") + signal( + proto.SMTP_RECIPIENT_ACCEPTED_UNENCRYPTED, + user.dest.addrstr) + + def encrypt_func(_): + return lambda: EncryptedMessage(user, self._outgoing_mail) + + d = self._km.get_key(address, OpenPGPKey) d.addCallbacks(found, not_found) - d.addCallback(lambda _: EncryptedMessage(user, self._outgoing_mail)) + d.addCallback(encrypt_func) return d def validateFrom(self, helo, origin): diff --git a/src/leap/mail/smtp/tests/test_gateway.py b/src/leap/mail/smtp/tests/test_gateway.py index 8cbff8f..0b9a364 100644 --- a/src/leap/mail/smtp/tests/test_gateway.py +++ b/src/leap/mail/smtp/tests/test_gateway.py @@ -23,7 +23,8 @@ SMTP gateway tests. import re from datetime import datetime -from twisted.internet.defer import inlineCallbacks, fail +from twisted.internet import reactor +from twisted.internet.defer import inlineCallbacks, fail, succeed, Deferred from twisted.test import proto_helpers from mock import Mock @@ -70,6 +71,7 @@ class TestSmtpGateway(TestCaseWithKeyManager): % (string, pattern)) raise self.failureException(msg) + @inlineCallbacks def test_gateway_accepts_valid_email(self): """ Test if SMTP server responds correctly for valid interaction. @@ -93,11 +95,11 @@ class TestSmtpGateway(TestCaseWithKeyManager): # snip... transport = proto_helpers.StringTransport() proto.makeConnection(transport) + reply = "" for i, line in enumerate(self.EMAIL_DATA): - proto.lineReceived(line + '\r\n') - self.assertMatch(transport.value(), - '\r\n'.join(SMTP_ANSWERS[0:i + 1]), - 'Did not get expected answer from gateway.') + reply += yield self.getReply(line + '\r\n', proto, transport) + self.assertMatch(reply, '\r\n'.join(SMTP_ANSWERS), + 'Did not get expected answer from gateway.') proto.setTimeout(None) @inlineCallbacks @@ -122,15 +124,16 @@ class TestSmtpGateway(TestCaseWithKeyManager): outgoing_mail=Mock()).buildProtocol(('127.0.0.1', 0)) transport = proto_helpers.StringTransport() proto.makeConnection(transport) - proto.lineReceived(self.EMAIL_DATA[0] + '\r\n') - proto.lineReceived(self.EMAIL_DATA[1] + '\r\n') - proto.lineReceived(self.EMAIL_DATA[2] + '\r\n') + yield self.getReply(self.EMAIL_DATA[0] + '\r\n', proto, transport) + yield self.getReply(self.EMAIL_DATA[1] + '\r\n', proto, transport) + reply = yield self.getReply(self.EMAIL_DATA[2] + '\r\n', + proto, transport) # ensure the address was rejected - lines = transport.value().rstrip().split('\n') self.assertEqual( - '550 Cannot receive for specified address', - lines[-1], + '550 Cannot receive for specified address\r\n', + reply, 'Address should have been rejecetd with appropriate message.') + proto.setTimeout(None) @inlineCallbacks def test_missing_key_accepts_address(self): @@ -153,12 +156,32 @@ class TestSmtpGateway(TestCaseWithKeyManager): False, outgoing_mail=Mock()).buildProtocol(('127.0.0.1', 0)) transport = proto_helpers.StringTransport() proto.makeConnection(transport) - proto.lineReceived(self.EMAIL_DATA[0] + '\r\n') - proto.lineReceived(self.EMAIL_DATA[1] + '\r\n') - proto.lineReceived(self.EMAIL_DATA[2] + '\r\n') + yield self.getReply(self.EMAIL_DATA[0] + '\r\n', proto, transport) + yield self.getReply(self.EMAIL_DATA[1] + '\r\n', proto, transport) + reply = yield self.getReply(self.EMAIL_DATA[2] + '\r\n', + proto, transport) # ensure the address was accepted - lines = transport.value().rstrip().split('\n') self.assertEqual( - '250 Recipient address accepted', - lines[-1], + '250 Recipient address accepted\r\n', + reply, 'Address should have been accepted with appropriate message.') + proto.setTimeout(None) + + def getReply(self, line, proto, transport): + proto.lineReceived(line) + + if line[:4] not in ['HELO', 'MAIL', 'RCPT', 'DATA']: + return succeed("") + + def check_transport(_): + reply = transport.value() + if reply: + transport.clear() + return succeed(reply) + + d = Deferred() + d.addCallback(check_transport) + reactor.callLater(0, lambda: d.callback(None)) + return d + + return check_transport(None) |