Skip to content

Commit

Permalink
Add functools.lru_cache to KNP extensions. (#39)
Browse files Browse the repository at this point in the history
`sent._.knp_tag_spans` and similar extensions are now cached.
  • Loading branch information
tamuhey authored Apr 9, 2020
1 parent a688fae commit ba0161d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 37 deletions.
6 changes: 6 additions & 0 deletions camphr/pipelines/knp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Defines KNP pipelines."""
import functools
import re
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple

Expand Down Expand Up @@ -103,6 +104,7 @@ def __call__(self, doc: Doc) -> Doc:


@curry
@functools.lru_cache()
def token_to_knp_span(type_: str, token: Token) -> Span:
"""Returns the knp span containing the token."""
assert type_ != MORPH
Expand All @@ -113,6 +115,7 @@ def token_to_knp_span(type_: str, token: Token) -> Span:


@curry
@functools.lru_cache()
def get_knp_span(type_: str, span: Span) -> List[Span]:
"""Get knp tag or bunsetsu list"""
assert type_ != MORPH
Expand All @@ -134,6 +137,7 @@ def get_knp_span(type_: str, span: Span) -> List[Span]:
return res


@functools.lru_cache()
def get_knp_element_id(elem) -> int:
from pyknp import Morpheme, Bunsetsu, Tag

Expand All @@ -154,6 +158,7 @@ def get_all_knp_features_from_sents(


@curry
@functools.lru_cache()
def get_knp_parent(type_: L_KNP_OBJ, span: Span) -> Optional[Span]:
tag_or_bunsetsu = span._.get(getattr(KNP_USER_KEYS, type_).element)
if not tag_or_bunsetsu:
Expand All @@ -166,6 +171,7 @@ def get_knp_parent(type_: L_KNP_OBJ, span: Span) -> Optional[Span]:


@curry
@functools.lru_cache()
def get_knp_children(type_: L_KNP_OBJ, span: Span) -> List[Span]:
tag_or_bunsetsu = span._.get(getattr(KNP_USER_KEYS, type_).element)
if not tag_or_bunsetsu:
Expand Down
79 changes: 42 additions & 37 deletions tests/pipelines/knp/test_knp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,37 @@ def nlp():
@pytest.mark.parametrize("text,ents", TESTCASES)
def test_knp_entity_extractor(nlp: Language, text: str, ents: Tuple[str]):
doc: Doc = nlp(text)
assert len(doc.ents) == len(ents)
for s, expected_ent in zip(doc.ents, ents):
assert s.text == expected_ent[0]
assert s.label_ == expected_ent[1]
for _ in range(2): # loop 2 times to validate cache
assert len(doc.ents) == len(ents)
for s, expected_ent in zip(doc.ents, ents):
assert s.text == expected_ent[0]
assert s.label_ == expected_ent[1]


@pytest.mark.parametrize("text", TEXTS)
def test_knp_span_getter(nlp: Language, text: str):
doc: Doc = nlp(text)
for sent in doc.sents:
blist = sent._.get(KNP_USER_KEYS.bunsetsu.list_)
text = "".join(b.midasi for b in blist)
assert text == sent.text
assert all(
[
b.midasi == s.text
for b, s in itertools.zip_longest(
blist, sent._.get(KNP_USER_KEYS.bunsetsu.spans)
)
]
)
assert all(
[
t.midasi == s.text
for t, s in itertools.zip_longest(
blist.tag_list(), sent._.get(KNP_USER_KEYS.tag.spans)
)
]
)
for _ in range(2): # loop 2 times to validate cache
for sent in doc.sents:
blist = sent._.get(KNP_USER_KEYS.bunsetsu.list_)
text = "".join(b.midasi for b in blist)
assert text == sent.text
assert all(
[
b.midasi == s.text
for b, s in itertools.zip_longest(
blist, sent._.get(KNP_USER_KEYS.bunsetsu.spans)
)
]
)
assert all(
[
t.midasi == s.text
for t, s in itertools.zip_longest(
blist.tag_list(), sent._.get(KNP_USER_KEYS.tag.spans)
)
]
)


@pytest.mark.parametrize(
Expand All @@ -67,10 +69,11 @@ def test_knp_span_getter(nlp: Language, text: str):
)
def test_knp_parent_getter(nlp: Language, text: str, parents: List[List[str]]):
doc: Doc = nlp(text)
for sent, pl in zip(doc.sents, parents):
spans = sent._.get(KNP_USER_KEYS.tag.spans)
ps = [span._.get(KNP_USER_KEYS.tag.parent) for span in spans]
assert [s.text if s else "" for s in ps] == [p for p in pl]
for _ in range(2): # loop 2 times to validate cache
for sent, pl in zip(doc.sents, parents):
spans = sent._.get(KNP_USER_KEYS.tag.spans)
ps = [span._.get(KNP_USER_KEYS.tag.parent) for span in spans]
assert [s.text if s else "" for s in ps] == [p for p in pl]


@pytest.mark.parametrize(
Expand All @@ -89,11 +92,12 @@ def test_knp_children_getter(
nlp: Language, text: str, children_list: List[List[List[str]]]
):
doc: Doc = nlp(text)
for sent, children_texts in zip(doc.sents, children_list):
spans = sent._.get(KNP_USER_KEYS.tag.spans)
children = [span._.get(KNP_USER_KEYS.tag.children) for span in spans]
children = [[cc.text for cc in c] for c in children]
assert children == children_texts
for _ in range(2): # loop 2 times to validate cache
for sent, children_texts in zip(doc.sents, children_list):
spans = sent._.get(KNP_USER_KEYS.tag.spans)
children = [span._.get(KNP_USER_KEYS.tag.children) for span in spans]
children = [[cc.text for cc in c] for c in children]
assert children == children_texts


@pytest.mark.parametrize("text", ["(※"])
Expand Down Expand Up @@ -122,7 +126,8 @@ def test_knp_doc_getter(nlp: Language):
)
def test_tag_or_bunsetsu_from_token(nlp: Language, text: str):
doc = nlp(text)
for k in ["tag", "bunsetsu"]:
for span in doc._.get(getattr(KNP_USER_KEYS, k).spans):
for token in span:
assert token._.get(getattr(KNP_USER_KEYS.morph, k)) == span
for _ in range(2): # loop 2 times to validate cache
for k in ["tag", "bunsetsu"]:
for span in doc._.get(getattr(KNP_USER_KEYS, k).spans):
for token in span:
assert token._.get(getattr(KNP_USER_KEYS.morph, k)) == span

0 comments on commit ba0161d

Please sign in to comment.