Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite: APNs: Scoped App Tokens #101

Merged
merged 19 commits into from
May 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: Pyright
on: [push, pull_request]
jobs:
pyright:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
cache: 'pip'

- run: |
python -m venv .venv
source .venv/bin/activate
pip install -e '.[test,cli]'
- run: echo "$PWD/.venv/bin" >> $GITHUB_PATH
- uses: jakebailey/pyright-action@v2
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: Ruff
on: [push, pull_request]
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -34,4 +34,8 @@ version_file = "pypush/_version.py"
[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "-q"]
testpaths = ["tests"]
testpaths = ["tests"]

[tool.ruff.lint]
select = ["E", "F", "B", "SIM", "I"]
ignore = ["E501", "B010"]
6 changes: 3 additions & 3 deletions pypush/apns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__all__ = ["protocol", "create_apns_connection", "activate"]
__all__ = ["protocol", "create_apns_connection", "activate", "filters"]

from . import protocol
from .lifecycle import create_apns_connection
from . import filters, protocol
from .albert import activate
from .lifecycle import create_apns_connection
14 changes: 7 additions & 7 deletions pypush/apns/_protocol.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
import logging
from dataclasses import MISSING, field
from dataclasses import fields as dataclass_fields
from typing import Any, TypeVar, get_origin, get_args, Union
from typing import Any, TypeVar, Union, get_args, get_origin

from pypush.apns.transport import Packet

@@ -67,14 +67,14 @@ def from_packet(cls, packet: Packet):
)

# Check for extra fields
for field in packet.fields:
if field.id not in [
for current_field in packet.fields:
if current_field.id not in [
f.metadata["packet_id"]
for f in dataclass_fields(cls)
if f.metadata is not None and "packet_id" in f.metadata
]:
logging.warning(
f"Unexpected field with packet ID {field.id} in packet {packet}"
f"Unexpected field with packet ID {current_field.id} in packet {packet}"
)
return cls(**field_values)

@@ -122,15 +122,15 @@ def fid(
:param byte_len: The length of the field in bytes (for int fields)
:param default: The default value of the field
"""
if not default == MISSING and not default_factory == MISSING:
if default != MISSING and default_factory != MISSING:
raise ValueError("Cannot specify both default and default_factory")
if not default == MISSING:
if default != MISSING:
return field(
metadata={"packet_id": packet_id, "packet_bytes": byte_len},
default=default,
repr=repr,
)
if not default_factory == MISSING:
if default_factory != MISSING:
return field(
metadata={"packet_id": packet_id, "packet_bytes": byte_len},
default_factory=default_factory,
50 changes: 45 additions & 5 deletions pypush/apns/_util.py
Original file line number Diff line number Diff line change
@@ -3,32 +3,72 @@
from typing import Generic, TypeVar

import anyio
from anyio.abc import ObjectSendStream
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from . import filters

T = TypeVar("T")


class BroadcastStream(Generic[T]):
def __init__(self):
def __init__(self, backlog: int = 50):
self.streams: list[ObjectSendStream[T]] = []
self.backlog: list[T] = []
self._backlog_size = backlog

async def broadcast(self, packet):
logging.debug(f"Broadcasting {packet} to {len(self.streams)} streams")
for stream in self.streams:
try:
await stream.send(packet)
except anyio.BrokenResourceError:
self.streams.remove(stream)
logging.error("Broken resource error")
# self.streams.remove(stream)
# If we have a backlog, add the packet to it
if len(self.backlog) >= self._backlog_size:
self.backlog.pop(0)
self.backlog.append(packet)

@asynccontextmanager
async def open_stream(self):
send, recv = anyio.create_memory_object_stream[T]()
async def open_stream(self, backlog: bool = True):
# 1000 seems like a reasonable number, if more than 1000 messages come in before someone deals with them it will
# start stalling the APNs connection itself
send, recv = anyio.create_memory_object_stream[T](max_buffer_size=1000)
if backlog:
for packet in self.backlog:
await send.send(packet)
self.streams.append(send)
async with recv:
yield recv
self.streams.remove(send)
await send.aclose()


W = TypeVar("W")
F = TypeVar("F")


class FilteredStream(ObjectReceiveStream[F]):
"""
A stream that filters out unwanted items
filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it
"""

def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]):
self.source = source
self.filter = filter

async def receive(self) -> F:
async for item in self.source:
if (filtered := self.filter(item)) is not None:
return filtered
raise anyio.EndOfStream

async def aclose(self):
await self.source.aclose()


def exponential_backoff(f):
async def wrapper(*args, **kwargs):
backoff = 1
6 changes: 3 additions & 3 deletions pypush/apns/albert.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
import re
import uuid
from base64 import b64decode
from typing import Tuple, Optional
from typing import Optional, Tuple

import httpx
from cryptography import x509
@@ -96,10 +96,10 @@ async def activate(

try:
protocol = re.search("<Protocol>(.*)</Protocol>", resp.text).group(1) # type: ignore
except AttributeError:
except AttributeError as e:
# Search for error text between <b> and </b>
error = re.search("<b>(.*)</b>", resp.text).group(1) # type: ignore
raise Exception(f"Failed to get certificate from Albert: {error}")
raise Exception(f"Failed to get certificate from Albert: {error}") from e

protocol = plistlib.loads(protocol.encode("utf-8"))

44 changes: 44 additions & 0 deletions pypush/apns/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging
from typing import Callable, Optional, Type, TypeVar

from pypush.apns import protocol

T1 = TypeVar("T1")
T2 = TypeVar("T2")
Filter = Callable[[T1], Optional[T2]]

# Chain with proper types so that subsequent filters only need to take output type of previous filter
T_IN = TypeVar("T_IN", bound=protocol.Command)
T_MIDDLE = TypeVar("T_MIDDLE", bound=protocol.Command)
T_OUT = TypeVar("T_OUT", bound=protocol.Command)


def chain(first: Filter[T_IN, T_MIDDLE], second: Filter[T_MIDDLE, T_OUT]):
def filter(command: T_IN) -> Optional[T_OUT]:
logging.debug(f"Filtering {command} with {first} and {second}")
filtered = first(command)
if filtered is None:
return None
return second(filtered)

return filter


T = TypeVar("T", bound=protocol.Command)


def cmd(type: Type[T]) -> Filter[protocol.Command, T]:
def filter(command: protocol.Command) -> Optional[T]:
if isinstance(command, type):
return command
return None

return filter


def ALL(c):
return c


def NONE(_):
return None
202 changes: 164 additions & 38 deletions pypush/apns/lifecycle.py
Original file line number Diff line number Diff line change
@@ -6,28 +6,34 @@
import time
import typing
from contextlib import asynccontextmanager
from hashlib import sha1

import anyio
from anyio.abc import TaskGroup
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa

from . import protocol, transport, _util
from . import _util, filters, protocol, transport


@asynccontextmanager
async def create_apns_connection(
certificate: x509.Certificate,
private_key: rsa.RSAPrivateKey,
token: typing.Optional[bytes] = None,
sandbox: bool = False,
courier: typing.Optional[str] = None,
):
async with anyio.create_task_group() as tg:
conn = Connection(tg, certificate, private_key, token, courier)
conn = Connection(
tg, certificate, private_key, token, sandbox, courier
) # Await connected for first time here, so that base token is set
yield conn
tg.cancel_scope.cancel() # Cancel the task group when the context manager exits
await conn.aclose() # Make sure to close the connection after the task group is cancelled
await (
conn.aclose()
) # Make sure to close the connection after the task group is cancelled


class Connection:
@@ -37,26 +43,44 @@ def __init__(
certificate: x509.Certificate,
private_key: rsa.RSAPrivateKey,
token: typing.Optional[bytes] = None,
sandbox: bool = False,
courier: typing.Optional[str] = None,
):

self.certificate = certificate
self.private_key = private_key
self.base_token = token
self._base_token = token

self._filters: dict[str, int] = {} # topic -> use count

self._connected = anyio.Event() # Only use for base_token property

self._conn = None
self._tg = task_group
self._broadcast = _util.BroadcastStream[protocol.Command]()
self._reconnect_lock = anyio.Lock()
self._send_lock = anyio.Lock()

self.sandbox = sandbox
if courier is None:
# Pick a random courier server from 1 to 50
courier = f"{random.randint(1, 50)}-courier.push.apple.com"
courier = (
f"{random.randint(1, 50)}-courier.push.apple.com"
if not sandbox
else f"{random.randint(1, 10)}-courier.sandbox.push.apple.com"
)
logging.debug(f"Using courier: {courier}")
self.courier = courier

self._tg.start_soon(self.reconnect)
self._tg.start_soon(self._ping_task)

@property
async def base_token(self) -> bytes:
if self._base_token is None:
await self._connected.wait()
assert self._base_token is not None
return self._base_token

async def _receive_task(self):
assert self._conn is not None
async for command in self._conn:
@@ -68,17 +92,22 @@ async def _ping_task(self):
while True:
await anyio.sleep(30)
logging.debug("Sending keepalive")
await self.send(protocol.KeepAliveCommand())
await self.receive(protocol.KeepAliveAck)
await self._send(protocol.KeepAliveCommand())
await self._receive(
filters.cmd(protocol.KeepAliveAck), backlog=False
) # Explicitly disable the backlog since we don't want to receive old acks

@_util.exponential_backoff
async def reconnect(self):
async with self._reconnect_lock: # Prevent weird situations where multiple reconnects are happening at once
if self._conn is not None:
logging.warning("Closing existing connection")
await self._conn.aclose()
self._conn = protocol.CommandStream(
await transport.create_courier_connection(courier=self.courier)

self._broadcast.backlog = [] # Clear the backlog

conn = protocol.CommandStream(
await transport.create_courier_connection(self.sandbox, self.courier)
)
cert = self.certificate.public_bytes(serialization.Encoding.DER)
nonce = (
@@ -89,53 +118,150 @@ async def reconnect(self):
signature = b"\x01\x01" + self.private_key.sign(
nonce, padding.PKCS1v15(), hashes.SHA1()
)
await self._conn.send(
await conn.send(
protocol.ConnectCommand(
push_token=self.base_token,
push_token=self._base_token,
state=1,
flags=69,
flags=65, # 69
certificate=cert,
nonce=nonce,
signature=signature,
)
)

# Don't set self._conn until we've sent the connect command
self._conn = conn

self._tg.start_soon(self._receive_task)
ack = await self.receive(protocol.ConnectAck)
ack = await self._receive(
filters.chain(
filters.cmd(protocol.ConnectAck),
lambda c: (
c
if (
c.token == self._base_token
if self._base_token is not None
else True
)
else None
),
)
)
logging.debug(f"Connected with ack: {ack}")
assert ack.status == 0
if self.base_token is None:
self.base_token = ack.token
if self._base_token is None:
self._base_token = ack.token
else:
assert ack.token == self.base_token
assert ack.token == self._base_token
if not self._connected.is_set():
self._connected.set()

await self._update_filter()

async def aclose(self):
if self._conn is not None:
await self._conn.aclose()
# Note: Will be reopened if task group is still running and ping task is still running

T = typing.TypeVar("T", bound=protocol.Command)
T = typing.TypeVar("T")

async def receive_stream(
self, filter: typing.Type[T], max: int = -1
) -> typing.AsyncIterator[T]:
async with self._broadcast.open_stream() as stream:
@asynccontextmanager
async def _receive_stream(
self,
filter: filters.Filter[protocol.Command, T] = lambda c: c,
backlog: bool = True,
):
async with self._broadcast.open_stream(backlog) as stream:
yield _util.FilteredStream(stream, filter)

async def _receive(
self, filter: filters.Filter[protocol.Command, T], backlog: bool = True
):
async with self._receive_stream(filter, backlog) as stream:
async for command in stream:
if isinstance(command, filter):
yield command
max -= 1
if max == 0:
break

async def receive(self, filter: typing.Type[T]) -> T:
async for command in self.receive_stream(filter, 1):
return command
raise ValueError("No matching command received")
return command
raise ValueError("Did not receive expected command")

async def send(self, command: protocol.Command):
async def _send(self, command: protocol.Command):
try:
assert self._conn is not None
await self._conn.send(command)
except Exception as e:
logging.warning(f"Error sending command, reconnecting")
async with self._send_lock:
assert self._conn is not None
await self._conn.send(command)
except Exception:
logging.warning("Error sending command, reconnecting")
await self.reconnect()
await self.send(command)
await self._send(command)

async def _update_filter(self):
await self._send(
protocol.FilterCommand(
token=await self.base_token,
enabled_topic_hashes=[
sha1(topic.encode()).digest() for topic in self._filters
],
)
)

@asynccontextmanager
async def _filter(self, topics: list[str]):
for topic in topics:
self._filters[topic] = self._filters.get(topic, 0) + 1
await self._update_filter()
yield
for topic in topics:
self._filters[topic] -= 1
if self._filters[topic] == 0:
del self._filters[topic]
await self._update_filter()

async def mint_scoped_token(self, topic: str) -> bytes:
topic_hash = sha1(topic.encode()).digest()
await self._send(
protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash)
)
ack = await self._receive(filters.cmd(protocol.ScopedTokenAck))
assert ack.status == 0
return ack.scoped_token

@asynccontextmanager
async def notification_stream(
JJTech0130 marked this conversation as resolved.
Show resolved Hide resolved
self,
topic: str,
token: typing.Optional[bytes] = None,
filter: filters.Filter[
protocol.SendMessageCommand, protocol.SendMessageCommand
] = filters.ALL,
):
if token is None:
token = await self.base_token
async with self._filter([topic]), self._receive_stream(
filters.chain(
filters.chain(
filters.chain(
filters.cmd(protocol.SendMessageCommand),
lambda c: c if c.token == token else None,
),
lambda c: (c if c.topic == topic else None),
),
filter,
)
) as stream:
yield stream

async def ack(self, command: protocol.SendMessageCommand, status: int = 0):
await self._send(
protocol.SendMessageAck(status=status, token=command.token, id=command.id)
)

async def expect_notification(
self,
topic: str,
token: typing.Optional[bytes] = None,
filter: filters.Filter[
protocol.SendMessageCommand, protocol.SendMessageCommand
] = filters.ALL,
) -> protocol.SendMessageCommand:
async with self.notification_stream(topic, token, filter) as stream:
command = await stream.receive()
await self.ack(command)
return command
51 changes: 33 additions & 18 deletions pypush/apns/protocol.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
from hashlib import sha1
from typing import Optional, Union

from anyio.abc import ByteStream, ObjectStream
from anyio.abc import ObjectStream

from pypush.apns._protocol import command, fid
from pypush.apns.transport import Packet
@@ -87,12 +87,7 @@ class FilterCommand(Command):

def _lookup_hashes(self, hashes: Optional[list[bytes]]):
JJTech0130 marked this conversation as resolved.
Show resolved Hide resolved
return (
[
KNOWN_TOPICS_LOOKUP[hash] if hash in KNOWN_TOPICS_LOOKUP else hash
for hash in hashes
]
if hashes
else []
[KNOWN_TOPICS_LOOKUP.get(hash, hash) for hash in hashes] if hashes else []
)

@property
@@ -140,6 +135,7 @@ class KeepAliveAck(Command):
PacketType = Packet.Type.KeepAliveAck
unknown: Optional[int] = fid(1)


@command
@dataclass
class SetStateCommand(Command):
@@ -182,15 +178,15 @@ def __post_init__(self):
) and not (self._token_topic_1 is not None and self._token_topic_2 is not None):
raise ValueError("topic, token, and outgoing must be set.")

if self.outgoing == True:
if self.outgoing is True:
assert self.topic and self.token
self._token_topic_1 = (
sha1(self.topic.encode()).digest()
if isinstance(self.topic, str)
else self.topic
)
self._token_topic_2 = self.token
elif self.outgoing == False:
elif self.outgoing is False:
assert self.topic and self.token
self._token_topic_1 = self.token
self._token_topic_2 = (
@@ -201,18 +197,14 @@ def __post_init__(self):
else:
assert self._token_topic_1 and self._token_topic_2
if len(self._token_topic_1) == 20: # SHA1 hash, topic
self.topic = (
KNOWN_TOPICS_LOOKUP[self._token_topic_1]
if self._token_topic_1 in KNOWN_TOPICS_LOOKUP
else self._token_topic_1
self.topic = KNOWN_TOPICS_LOOKUP.get(
JJTech0130 marked this conversation as resolved.
Show resolved Hide resolved
self._token_topic_1, self._token_topic_1
)
self.token = self._token_topic_2
self.outgoing = True
else:
self.topic = (
KNOWN_TOPICS_LOOKUP[self._token_topic_2]
if self._token_topic_2 in KNOWN_TOPICS_LOOKUP
else self._token_topic_2
self.topic = KNOWN_TOPICS_LOOKUP.get(
self._token_topic_2, self._token_topic_2
)
self.token = self._token_topic_1
self.outgoing = False
@@ -229,6 +221,27 @@ class SendMessageAck(Command):
unknown6: Optional[bytes] = fid(6, default=None)


@command
@dataclass
class ScopedTokenCommand(Command):
PacketType = Packet.Type.ScopedToken

token: bytes = fid(1)
topic: bytes = fid(2)
app_id: Optional[bytes] = fid(3, default=None)


@command
@dataclass
class ScopedTokenAck(Command):
PacketType = Packet.Type.ScopedTokenAck

status: int = fid(1)
scoped_token: bytes = fid(2)
topic: bytes = fid(3)
app_id: Optional[bytes] = fid(4, default=None)


@dataclass
class UnknownCommand(Command):
id: Packet.Type
@@ -240,7 +253,7 @@ def from_packet(cls, packet: Packet):

def to_packet(self) -> Packet:
return Packet(id=self.id, fields=self.fields)

def __repr__(self):
if self.id.value in [29, 30, 32]:
return f"UnknownCommand(id={self.id}, fields=[SUPPRESSED])"
@@ -259,6 +272,8 @@ def command_from_packet(packet: Packet) -> Command:
Packet.Type.SetState: SetStateCommand,
Packet.Type.SendMessage: SendMessageCommand,
Packet.Type.SendMessageAck: SendMessageAck,
Packet.Type.ScopedToken: ScopedTokenCommand,
Packet.Type.ScopedTokenAck: ScopedTokenAck,
# Add other mappings here...
}
command_class = command_classes.get(packet.id, None)
16 changes: 12 additions & 4 deletions pypush/apns/transport.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,8 @@ class Type(Enum):
KeepAlive = 12
KeepAliveAck = 13
NoStorage = 14
ScopedToken = 17
ScopedTokenAck = 18
SetState = 20
UNKNOWN = "Unknown"

@@ -38,20 +40,19 @@ def __new__(cls, value):
obj = object.__new__(cls)
obj._value_ = value
return obj

@classmethod
def _missing_(cls, value):
# Handle unknown values
instance = cls.UNKNOWN
instance._value_ = value # Assign the unknown value
return instance

def __str__(self):
if self is Packet.Type.UNKNOWN:
return f"Unknown({self._value_})"
return self.name


id: Type
fields: list[Field]

@@ -60,18 +61,25 @@ def fields_for_id(self, id: int) -> list[bytes]:


async def create_courier_connection(
sandbox: bool = False,
courier: str = "1-courier.push.apple.com",
) -> PacketStream:
context = ssl.create_default_context()
context.set_alpn_protocols(ALPN)

sni = "courier.sandbox.push.apple.com" if sandbox else "courier.push.apple.com"

# TODO: Verify courier certificate
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE

return PacketStream(
await anyio.connect_tcp(
courier, COURIER_PORT, ssl_context=context, tls_standard_compatible=False
courier,
COURIER_PORT,
ssl_context=context,
tls_standard_compatible=False,
tls_hostname=sni,
)
)

38 changes: 32 additions & 6 deletions pypush/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import contextlib
import logging
from asyncio import CancelledError

import anyio
import typer
from rich.logging import RichHandler
from typing_extensions import Annotated

from pypush import apns

from . import proxy as _proxy

logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s")
logging.basicConfig(level=logging.INFO, handlers=[RichHandler()], format="%(message)s")

app = typer.Typer()

@@ -22,12 +27,12 @@ def proxy(
Attach requires SIP to be disabled and to be running as root
"""

_proxy.main(attach)
with contextlib.suppress(CancelledError):
_proxy.main(attach)


@app.command()
def client(
def notifications(
topic: Annotated[str, typer.Argument(help="app topic to listen on")],
sandbox: Annotated[
bool, typer.Option("--sandbox/--production", help="APNs courier to use")
@@ -36,8 +41,29 @@ def client(
"""
Connect to the APNs courier and listen for app notifications on the given topic
"""
typer.echo("Running APNs client")
raise NotImplementedError("Not implemented yet")
logging.getLogger("httpx").setLevel(logging.WARNING)
with contextlib.suppress(CancelledError):
anyio.run(notifications_async, topic, sandbox)


async def notifications_async(topic: str, sandbox: bool):
async with apns.create_apns_connection(
*await apns.activate(),
JJTech0130 marked this conversation as resolved.
Show resolved Hide resolved
courier="1-courier.sandbox.push.apple.com"
if sandbox
else "1-courier.push.apple.com",
) as connection:
token = await connection.mint_scoped_token(topic)

async with connection.notification_stream(topic, token) as stream:
logging.info(
f"Listening for notifications on topic {topic} ({'sandbox' if sandbox else 'production'})"
)
logging.info(f"Token: {token.hex()}")

async for notification in stream:
await connection.ack(notification)
logging.info(notification.payload.decode())


def main():
3 changes: 2 additions & 1 deletion pypush/cli/_frida.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import frida
import logging

import frida


def attach_to_apsd() -> frida.core.Session:
frida.kill("apsd")
8 changes: 3 additions & 5 deletions pypush/cli/proxy.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@
import logging
import ssl
import tempfile
from typing import Optional

import anyio
import anyio.abc
@@ -12,11 +11,10 @@
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from cryptography.hazmat.primitives.serialization import Encoding

# from pypush import apns
from pypush.apns import transport
from pypush.apns import protocol
from pypush.apns import protocol, transport

from . import _frida

@@ -71,7 +69,7 @@ async def handle(client: TLSStream):
else "1-courier.sandbox.push.apple.com"
)
name = f"prod-{connection_cnt}" if not sandbox else f"sandbox-{connection_cnt}"
async with await transport.create_courier_connection(forward) as conn:
async with await transport.create_courier_connection(sandbox, forward) as conn:
logging.debug("Connected to courier")
async with anyio.create_task_group() as tg:
tg.start_soon(forward_packets, client_pkt, conn, f"client-{name}")
Empty file removed pypush/cli/pushclient.py
Empty file.
75 changes: 75 additions & 0 deletions tests/assets/dev.jjtech.pypush.tests.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
Bag Attributes
friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests
localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3
subject=UID=dev.jjtech.pypush.tests, CN=Apple Sandbox Push Services: dev.jjtech.pypush.tests, OU=C4492JYJR3, C=US
issuer=CN=Apple Worldwide Developer Relations Certification Authority, OU=G4, O=Apple Inc., C=US
-----BEGIN CERTIFICATE-----
MIIGnzCCBYegAwIBAgIQRLQgelpeA0ozi3PDbx2ZmTANBgkqhkiG9w0BAQsFADB1
MUQwQgYDVQQDDDtBcHBsZSBXb3JsZHdpZGUgRGV2ZWxvcGVyIFJlbGF0aW9ucyBD
ZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTELMAkGA1UECwwCRzQxEzARBgNVBAoMCkFw
cGxlIEluYy4xCzAJBgNVBAYTAlVTMB4XDTI0MDUxNjAwMTUwM1oXDTI1MDYxNTAw
MTUwMlowgYoxJzAlBgoJkiaJk/IsZAEBDBdkZXYuamp0ZWNoLnB5cHVzaC50ZXN0
czE9MDsGA1UEAww0QXBwbGUgU2FuZGJveCBQdXNoIFNlcnZpY2VzOiBkZXYuamp0
ZWNoLnB5cHVzaC50ZXN0czETMBEGA1UECwwKQzQ0OTJKWUpSMzELMAkGA1UEBhMC
VVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQD3BvhGnrBtXpVLVvdi
HFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+pcYa
XK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt8J+Y
RHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0GVcI
0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0taBW
rdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6p8mC
TzZhAgMBAAGjggMTMIIDDzAMBgNVHRMBAf8EAjAAMB8GA1UdIwQYMBaAFFvZ+h3n
mhoLo5l2IlCGPpHIW3eoMHAGCCsGAQUFBwEBBGQwYjAtBggrBgEFBQcwAoYhaHR0
cDovL2NlcnRzLmFwcGxlLmNvbS93d2RyZzQuZGVyMDEGCCsGAQUFBzABhiVodHRw
Oi8vb2NzcC5hcHBsZS5jb20vb2NzcDAzLXd3ZHJnNDAzMIIBHgYDVR0gBIIBFTCC
AREwggENBgkqhkiG92NkBQEwgf8wgcMGCCsGAQUFBwICMIG2DIGzUmVsaWFuY2Ug
b24gdGhpcyBjZXJ0aWZpY2F0ZSBieSBhbnkgcGFydHkgYXNzdW1lcyBhY2NlcHRh
bmNlIG9mIHRoZSB0aGVuIGFwcGxpY2FibGUgc3RhbmRhcmQgdGVybXMgYW5kIGNv
bmRpdGlvbnMgb2YgdXNlLCBjZXJ0aWZpY2F0ZSBwb2xpY3kgYW5kIGNlcnRpZmlj
YXRpb24gcHJhY3RpY2Ugc3RhdGVtZW50cy4wNwYIKwYBBQUHAgEWK2h0dHBzOi8v
d3d3LmFwcGxlLmNvbS9jZXJ0aWZpY2F0ZWF1dGhvcml0eS8wEwYDVR0lBAwwCgYI
KwYBBQUHAwIwMgYDVR0fBCswKTAnoCWgI4YhaHR0cDovL2NybC5hcHBsZS5jb20v
d3dkcmc0LTMuY3JsMB0GA1UdDgQWBBQKyU1l8TlEc1+oBby5AEcULBKa8zAOBgNV
HQ8BAf8EBAMCB4Awgb8GCiqGSIb3Y2QGAwYEgbAwga0MF2Rldi5qanRlY2gucHlw
dXNoLnRlc3RzMAcMBXRvcGljDBxkZXYuamp0ZWNoLnB5cHVzaC50ZXN0cy52b2lw
MAYMBHZvaXAMJGRldi5qanRlY2gucHlwdXNoLnRlc3RzLmNvbXBsaWNhdGlvbjAO
DAxjb21wbGljYXRpb24MIGRldi5qanRlY2gucHlwdXNoLnRlc3RzLnZvaXAtcHR0
MAsMCS52b2lwLXB0dDAQBgoqhkiG92NkBgMBBAIFADANBgkqhkiG9w0BAQsFAAOC
AQEAwQac2q1BMnAH1vdZgfDunc+b7SKO6rJIG6w/wl4211YyNBBS5oabQnQDfB8y
8iOeWnoWXry60gI2fwWN/rRaQn4QCy72jNeTGz/T/s2jwoGj89114JjcBhRAHvQl
/HN4QjSt5rWVRcxTE4cKKbJIqVCm7Uq9VROgbxXrmsZsRnyk1ASvLGboibtGbmty
wmXZWns5NXNDbv1wP+PF5HSFXtDWodPYnhvzJe0s9lRvo4yGAt1KL5mNaZM3kKp0
74kdzKK/iT7954EQK4ZWPQbDnS1A+/BzHQjK0rWTwjDQkbKvNE9bb+KJbNHH3+DX
5s0ybZYoG5meGKUplwu7A2bfFw==
-----END CERTIFICATE-----
Bag Attributes
friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests Private Key
localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3
Key Attributes: <No Attributes>
-----BEGIN PRIVATE KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQD3BvhGnrBtXpVL
VvdiHFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+
pcYaXK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt
8J+YRHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0
GVcI0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0
taBWrdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6
p8mCTzZhAgMBAAECggEBAKADb8eu+3GdFvAagVyYI5wq5Vik1uu0vFKD+cfFeQQT
bCTxe/TTkAYSybwJEb0Zjy0spE1rgfzHbTFsiIqDBs1TqsZnPuPEhrzXMfVcyTqt
I3yjlMAFPeAkEqcfmdUiPgp64zHHNmI8lBSoDXlAwypY6PnwArtAI3MItTFcElhX
gWB44xVGuJRjRP4UVqXg0ML/Ic2yuYT9DRsDRilYhm8RGRSHkdZKdzCicMZcLtC7
bs6/evmIrk9V5AzF6YiXlfT0dOp6yy9mFwhLljXF3Z2/LdrOTAmhLPQRMbUrJrcW
ZPd0kMybGIlEoprQEA/6nZkdtIiDo2OJtufCs8g+nJECgYEA/+v4uTJzEI1igKOB
myJtADECZAsJUaJaKSAM7VHn1hNOKgNLhUHOuroWvIWEhEomWeMvCbZIG42eOwNW
BXGtG7ruT79E6655dljU6E/029FaxONqXXCTD9ZPh031R293KcydMwgBJJ0pvFJE
14HWmMRAG0auPygMRhXubXU1ndMCgYEA9xpNWrl9poTjsZDNqvu60nYcq0W1escw
ovmb87uxZ5u8fC8T1F3AVMYj4v0dTyA4F0mZenY+nri/hJBuanWVxa5Liu0fGnBr
tEa2rzCMaajoDTNMKSygFz6CIMZbbZhozy0+9DHcRcC6b2UtIgB/+/ZQtrTvQ8Ea
i6viarkq1nsCgYBznYAM8mynEqhoYvV/RyslBf8FgTLhjU3b/F26rODmhmwucLSi
a9tf4ge5fTwjo3f17btnUND8mZrdICGxbex9dZKJtmgFbRn0TCdLGCwPTmIKRo7b
zaqyYeglwSNI9WNJH+X4kuopR1L+f9AX59ExzJ8Fc4XuhEIfO3MuQeBJ/wKBgQDa
8AgH0X/+EZJ42rcPvxiprxL5wbrpPSHf1M+T5gJqrXcUhNXJ/QMTWbekP+Y/HGn2
YDTHZ4tWMJUoTJw4YVTBoQu33R8I2wDi6yCkGpzeZVStlXzuomZ6Ed1UUsvhT//V
SN6VmLP1ba0CVB/oF49OXNDpAWlZm/f8NuBW9Rd6jwKBgQDi495IOjLJ8SvWRJLT
c9AUmO7IVgipWvr51cF9IYxkzXIVIQIh1usy2NsrBxshAD+FbbWFVBfoptdKBZVK
J8u+Ou4gTxs8SdGKGZWZpUMEKJbPsq8lE2aU3mBXiWcFRxYpu+n7nKap0Lla/xBD
v77FY1M3FxGR6rNqPJQ9rRLFbA==
-----END PRIVATE KEY-----
61 changes: 41 additions & 20 deletions tests/test_apns.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
import pytest
from pypush import apns
import asyncio

# from aioapns import *
import logging
import uuid
import anyio

# from pypush.apns import _util
# from pypush.apns import albert, lifecycle, protocol
from pypush import apns
from pathlib import Path

import logging
import httpx
import pytest
from rich.logging import RichHandler

from pypush import apns

logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s")


@@ -26,17 +21,43 @@ async def test_activate():

@pytest.mark.asyncio
async def test_lifecycle_2():
async with apns.create_apns_connection(
certificate, key, courier="localhost"
) as connection:
await connection.receive(
apns.protocol.ConnectAck
) # Just wait until the initial connection is established. Don't do this in real code plz.
async with apns.create_apns_connection(certificate, key) as _:
pass


ASSETS_DIR = Path(__file__).parent / "assets"


async def send_test_notification(device_token, payload=b"hello, world"):
async with httpx.AsyncClient(
cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True
) as client:
# Use the certificate and key from above
response = await client.post(
f"https://api.sandbox.push.apple.com/3/device/{device_token}",
content=payload,
headers={
"apns-topic": "dev.jjtech.pypush.tests",
"apns-push-type": "alert",
"apns-priority": "10",
},
)
assert response.status_code == 200


@pytest.mark.asyncio
async def test_shorthand():
async def test_scoped_token():
async with apns.create_apns_connection(
*await apns.activate(), courier="localhost"
*await apns.activate(), sandbox=True
) as connection:
await connection.receive(apns.protocol.ConnectAck)
token = await connection.mint_scoped_token("dev.jjtech.pypush.tests")

test_message = f"test-message-{uuid.uuid4().hex}"

await send_test_notification(token.hex(), test_message.encode())

await connection.expect_notification(
"dev.jjtech.pypush.tests",
token,
lambda c: c if c.payload == test_message.encode() else None,
)