diff --git a/apns.py b/apns.py old mode 100644 new mode 100755 index 33655b1..532053b --- a/apns.py +++ b/apns.py @@ -267,37 +267,44 @@ def write(self, string): return self._connection().write(string) -class PayloadAlert(object): - def __init__(self, body=None, title = None, subtitle = None, action_loc_key=None, loc_key=None, - loc_args=None, launch_image=None): - super(PayloadAlert, self).__init__() - - self.body = body - self.title = title - self.subtitle = subtitle - self.action_loc_key = action_loc_key - self.loc_key = loc_key - self.loc_args = loc_args - self.launch_image = launch_image +class PayloadAlert(dict): + """ + Payload for APNS alert. + https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/PayloadKeyReference.html + """ + def __init__(self, + body=None, + title=None, + subtitle=None, + action_loc_key=None, + loc_key=None, + loc_args=None, + launch_image=None, + title_loc_key=None, + title_loc_args=None): + dict_ = { + 'body': body, + 'title': title, + 'subtitle': subtitle, + 'action-loc-key': action_loc_key, + 'loc-key': loc_key, + 'loc-args': loc_args, + 'launch-image': launch_image, + 'title-loc-key': title_loc_key, + 'title-loc-args': title_loc_args + } + + # init dictionary with non None items + super(PayloadAlert, self).__init__( + { + key: value for (key, value) + in dict_.items() if value is not None + } + ) def dict(self): - d = {} - - if self.body: - d['body'] = self.body - if self.title: - d['title'] = self.title - if self.subtitle: - d['subtitle'] = self.subtitle - if self.action_loc_key: - d['action-loc-key'] = self.action_loc_key - if self.loc_key: - d['loc-key'] = self.loc_key - if self.loc_args: - d['loc-args'] = self.loc_args - if self.launch_image: - d['launch-image'] = self.launch_image - return d + return self + class PayloadTooLargeError(Exception): def __init__(self, payload_size): diff --git a/setup.py b/setup.py index 1d66940..644ff24 100644 --- a/setup.py +++ b/setup.py @@ -10,5 +10,5 @@ py_modules = ['apns'], scripts = ['apns-send'], url = 'http://29.io/', - version = '2.0.1', + version = 'canwehatch.2.0.2', ) diff --git a/tests.py b/tests.py old mode 100644 new mode 100755 index fb17a54..6a6254c --- a/tests.py +++ b/tests.py @@ -5,6 +5,7 @@ from random import random import hashlib +import json import os import time import unittest @@ -130,6 +131,11 @@ def testPayloadAlert(self): self.assertTrue('body' not in d) self.assertEqual(d['loc-key'], 'wibble') + def testPayloadAlertJSONSerializable(self): + pa = PayloadAlert('foo', action_loc_key='bar', loc_key='wibble', + loc_args=['king', 'kong'], launch_image='wobble') + self.assertEqual(pa, json.loads(json.dumps(pa))) + def testPayload(self): # Payload with just alert p = Payload(alert=PayloadAlert('foo')) @@ -188,9 +194,19 @@ def testFrame(self): frame = Frame() frame.add_item(token_hex, payload, identifier, expiry, priority) - f1 = bytearray(b'\x02\x00\x00\x00t\x01\x00 \xb5\xbb\x9d\x80\x14\xa0\xf9\xb1\xd6\x1e!\xe7\x96\xd7\x8d\xcc\xdf\x13R\xf2<\xd3(\x12\xf4\x85\x0b\x87\x8a\xe4\x94L\x02\x00<{"aps":{"sound":"default","badge":4,"alert":"Hello World!"}}\x03\x00\x04\x00\x00\x00\x01\x04\x00\x04\x00\x00\x0e\x10\x05\x00\x01\n') - f2 = bytearray(b'\x02\x00\x00\x00t\x01\x00 \xb5\xbb\x9d\x80\x14\xa0\xf9\xb1\xd6\x1e!\xe7\x96\xd7\x8d\xcc\xdf\x13R\xf2<\xd3(\x12\xf4\x85\x0b\x87\x8a\xe4\x94L\x02\x00<{"aps":{"sound":"default","alert":"Hello World!","badge":4}}\x03\x00\x04\x00\x00\x00\x01\x04\x00\x04\x00\x00\x0e\x10\x05\x00\x01\n') - self.assertTrue(f1 == frame.get_frame() or f2 == frame.get_frame()) + frame_bytes = frame.get_frame() + + prefix = frame_bytes[:43] + data = frame_bytes[43: -18] + postfix = frame_bytes[-18:] + + self.assertEqual(prefix, b'\x02\x00\x00\x00t\x01\x00 \xb5\xbb\x9d\x80\x14\xa0\xf9\xb1\xd6\x1e!\xe7\x96\xd7\x8d\xcc\xdf\x13R\xf2<\xd3(\x12\xf4\x85\x0b\x87\x8a\xe4\x94L\x02\x00<') + self.assertEqual( + json.loads(data.decode()), + json.loads('{"aps":{"sound":"default","badge":4,' + '"alert":"Hello World!"}}') + ) + self.assertEqual(postfix, b'\x03\x00\x04\x00\x00\x00\x01\x04\x00\x04\x00\x00\x0e\x10\x05\x00\x01\n') def testPayloadTooLargeError(self): # The maximum size of the JSON payload is MAX_PAYLOAD_LENGTH @@ -209,5 +225,6 @@ def testPayloadTooLargeError(self): self.assertRaises(PayloadTooLargeError, Payload, u'\u0100' * (int(max_raw_payload_bytes / 2) + 1)) + if __name__ == '__main__': unittest.main()