summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/leap/mail/mail.py48
1 files changed, 37 insertions, 11 deletions
diff --git a/src/leap/mail/mail.py b/src/leap/mail/mail.py
index 2f190d4..0aede6b 100644
--- a/src/leap/mail/mail.py
+++ b/src/leap/mail/mail.py
@@ -117,6 +117,33 @@ def _unpack_headers(headers_dict):
return headers_l
+def _get_index_for_cdoc(part_map, cdocs_dict):
+ """
+ Get, if possible, the index for a given content-document matching the phash
+ of the passed part_map.
+
+ This is used when we are initializing a MessagePart, because we just pass a
+ reference to the parent message cdocs container and we need to iterate
+ through the cdocs to figure out which content-doc matches the phash of the
+ part we're currently rendering.
+
+ It is also used when recursing through a nested multipart message, because
+ in the initialization of the child MessagePart we pass a dictionary only
+ for the referenced cdoc.
+
+ :param part_map: a dict describing the mapping of the parts for the current
+ message-part.
+ :param cdocs: a dict of content-documents, 0-indexed.
+ :rtype: int
+ """
+ phash = part_map.get('phash', None)
+ if phash:
+ for i, cdoc_wrapper in cdocs_dict.items():
+ if cdoc_wrapper.phash == phash:
+ return i
+ return None
+
+
class MessagePart(object):
# TODO This class should be better abstracted from the data model.
# TODO support arbitrarily nested multiparts (right now we only support
@@ -144,13 +171,7 @@ class MessagePart(object):
self._pmap = part_map
self._cdocs = cdocs
- index = 1
- phash = part_map.get('phash', None)
- if phash:
- for i, cdoc_wrapper in self._cdocs.items():
- if cdoc_wrapper.phash == phash:
- index = i
- break
+ index = _get_index_for_cdoc(part_map, self._cdocs) or 1
self._index = index
def get_size(self):
@@ -171,7 +192,8 @@ class MessagePart(object):
if not multi:
payload = self._get_payload(self._index)
else:
- # XXX uh, multi also... should recurse"
+ # XXX uh, multi also... should recurse.
+ # This needs to be implemented in a more general and elegant way.
raise NotImplementedError
if payload:
payload = _encode_payload(payload)
@@ -190,11 +212,15 @@ class MessagePart(object):
sub_pmap = self._pmap.get("part_map", {})
try:
- part_map = sub_pmap[str(part + 1)]
+ part_map = sub_pmap[str(part)]
except KeyError:
- logger.debug("getSubpart for %s: KeyError" % (part,))
+ log.msg("getSubpart for %s: KeyError" % (part,))
raise IndexError
- return MessagePart(part_map, cdocs={1: self._cdocs.get(part + 1, {})})
+
+ cdoc_index = _get_index_for_cdoc(part_map, self._cdocs)
+ cdoc = self._cdocs.get(cdoc_index, {})
+
+ return MessagePart(part_map, cdocs={1: cdoc})
def _get_payload(self, index):
cdoc_wrapper = self._cdocs.get(index, None)