From 8799cfc9b062aeaa8d9608da4b3ef3ef3970d7d4 Mon Sep 17 00:00:00 2001 From: Daan van der Kallen Date: Mon, 30 May 2022 11:06:37 +0200 Subject: [PATCH] Improve virtual relation API --- binder/views.py | 29 ++++++++++++++++++---------- tests/test_virtual_relations.py | 34 +++++++++++++++++++++++++++++++++ tests/testapp/views/animal.py | 21 ++++++++++++++++++++ 3 files changed, 74 insertions(+), 10 deletions(-) create mode 100644 tests/test_virtual_relations.py diff --git a/binder/views.py b/binder/views.py index 2c419db4..0eef5b19 100644 --- a/binder/views.py +++ b/binder/views.py @@ -1048,7 +1048,7 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): result = {} singular_fields = set() - rel_ids_by_field_by_id = defaultdict(lambda: defaultdict(list)) + rel_ids_by_field_by_id = {} virtual_fields = set() for field in with_map: @@ -1083,7 +1083,7 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): raise BinderRequestError('Annotation for virtual relation {{{}}}.{{{}}} is {{{}}}, but no method by that name exists.'.format( self.model.__name__, field, virtual_annotation )) - rel_ids_by_field_by_id[field] = func(request, pks, q) + field_rel_ids = func(request, pks, q) # Actual relation else: f = self.model._meta.get_field(field) @@ -1112,8 +1112,16 @@ def _get_with_ids(self, pks, request, include_annotations, with_map, where_map): .distinct() ) + field_rel_ids = defaultdict(list) for pk, rel_pk in query: - rel_ids_by_field_by_id[field][pk].append(rel_pk) + field_rel_ids[pk].append(rel_pk) + + # Make sure we always have a result which has exactly the requested + # pks as keys + rel_ids_by_field_by_id[field] = { + pk: list(field_rel_ids.get(pk, [])) + for pk in pks + } for field, sub_fields in with_map.items(): next = self._follow_related(field)[0].model @@ -1413,14 +1421,15 @@ def order_by(self, queryset, request): def _annotate_obj_with_related_withs(self, obj, field_results): for (w, (view, ids_dict, is_singular)) in field_results.items(): - if '.' not in w: - if is_singular: - try: - obj[w] = list(ids_dict[obj['id']])[0] - except IndexError: - obj[w] = None + if '.' not in w and obj['id'] in ids_dict: + if not is_singular: + obj[w] = ids_dict[obj['id']] + elif len(ids_dict[obj['id']]) == 0: + obj[w] = None + elif len(ids_dict[obj['id']]) == 1: + obj[w] = ids_dict[obj['id']][0] else: - obj[w] = list(ids_dict[obj['id']]) + raise ValueError(f'multiple values for singular relation {w}: {obj["id"]} -> {ids_dict[obj["id"]]}') def _generate_meta(self, include_meta, queryset, request, pk=None): diff --git a/tests/test_virtual_relations.py b/tests/test_virtual_relations.py new file mode 100644 index 00000000..5ffbeeb9 --- /dev/null +++ b/tests/test_virtual_relations.py @@ -0,0 +1,34 @@ +import json + +from django.contrib.auth.models import User +from django.test import TestCase + +from .testapp.models import Zoo, Animal + + +class VirtualRelationTestCase(TestCase): + + def setUp(self): + user = User(username='testuser', is_active=True, is_superuser=True) + user.set_password('test') + user.save() + + self.client.login(username='testuser', password='test') + + def test_virtual_relation(self): + pride_rock = Zoo.objects.create(name='Pride Rock') + simba = Animal.objects.create(zoo=pride_rock, name='Simba') + nala = Animal.objects.create(zoo=pride_rock, name='Nala') + + res = self.client.get(f'/zoo/?with=animals.neighbours&where=animals(id={simba.id})') + self.assertEqual(res.status_code, 200) + res = json.loads(res.content) + + animals_by_id = { + obj['id']: obj + for obj in res['with']['animal'] + } + + self.assertEqual(set(animals_by_id), {simba.id, nala.id}) + self.assertEqual(animals_by_id[simba.id]['neighbours'], [nala.id]) + self.assertNotIn('neighbours', animals_by_id[nala.id]) diff --git a/tests/testapp/views/animal.py b/tests/testapp/views/animal.py index a49044be..9c537ca7 100644 --- a/tests/testapp/views/animal.py +++ b/tests/testapp/views/animal.py @@ -1,3 +1,5 @@ +from django.db.models import F + from binder.views import ModelView from ..models import Animal @@ -8,3 +10,22 @@ class AnimalView(ModelView): m2m_fields = ['costume'] searches = ['name__icontains'] transformed_searches = {'zoo_id': int} + + virtual_relations = { + 'neighbours': { + 'model': Animal, + 'annotation': '_virtual_neighbours', + 'singular': False, + }, + } + + def _virtual_neighbours(self, request, pks, q): + neighbours = {} + for pk, neighbour_pk in ( + Animal.objects + .filter(q, zoo__animals__pk__in=pks) + .values_list('zoo__animals__pk', 'pk') + ): + if neighbour_pk != pk: + neighbours.setdefault(pk, []).append(neighbour_pk) + return neighbours