Skip to content

Commit

Permalink
Setting line length to 120
Browse files Browse the repository at this point in the history
  • Loading branch information
ChWick committed May 17, 2021
1 parent fcef59c commit d12c938
Show file tree
Hide file tree
Showing 99 changed files with 481 additions and 1,542 deletions.
4 changes: 1 addition & 3 deletions calamari_ocr/ocr/augmentation/data_augmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def augment_data(self, data, gt_txt, n_augmentations):
def augment_data_tuple(self, t):
return self.augment_data(*t)

def augment_datas(
self, datas, gt_txts, n_augmentations, processes=1, progress_bar=False
):
def augment_datas(self, datas, gt_txts, n_augmentations, processes=1, progress_bar=False):
if n_augmentations < 0 or not isinstance(n_augmentations, int):
raise ValueError("Number of augmentation must be an integer >= 0")

Expand Down
16 changes: 4 additions & 12 deletions calamari_ocr/ocr/augmentation/dataaugmentationparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,13 @@ def to_dict(self):
@staticmethod
def from_factor(n):
if n >= 1:
return DataAugmentationAmount(
DataAugmentationAmountReference.ABSOLUTE, int(n), None
)
return DataAugmentationAmount(DataAugmentationAmountReference.ABSOLUTE, int(n), None)
elif n > 0:
return DataAugmentationAmount(
DataAugmentationAmountReference.PERCENTAGE, None, n
)
return DataAugmentationAmount(DataAugmentationAmountReference.PERCENTAGE, None, n)
elif n == 0:
return DataAugmentationAmount(
DataAugmentationAmountReference.PERCENTAGE, 0, 0
)
return DataAugmentationAmount(DataAugmentationAmountReference.PERCENTAGE, 0, 0)
else:
raise ValueError(
"Factor must be between (0, +infinity) but got {}".format(n)
)
raise ValueError("Factor must be between (0, +infinity) but got {}".format(n))

def __init__(
self,
Expand Down
26 changes: 6 additions & 20 deletions calamari_ocr/ocr/dataset/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
class CodecConstructionParams:
keep_loaded: bool = field(
default=True,
metadata=pai_meta(
help="Fully include the codec of the loaded model to the new codec"
),
metadata=pai_meta(help="Fully include the codec of the loaded model to the new codec"),
)
auto_compute: bool = field(
default=True,
Expand All @@ -34,14 +32,10 @@ class CodecConstructionParams:
)
include_files: List[str] = field(
default_factory=list,
metadata=pai_meta(
help="Whitelist of txt files that may not be removed on restoring a model"
),
metadata=pai_meta(help="Whitelist of txt files that may not be removed on restoring a model"),
)

resolved_include_chars: Set[str] = field(
default_factory=set, metadata=pai_meta(mode="ignore")
)
resolved_include_chars: Set[str] = field(default_factory=set, metadata=pai_meta(mode="ignore"))

def __post_init__(self):
# parse whitelist
Expand All @@ -60,9 +54,7 @@ def __post_init__(self):
@pai_dataclass(no_assign_to_unknown=False)
@dataclass
class Codec:
charset: List[
str
] # this filed will be used to store and load a the Codec from json
charset: List[str] # this field is used to store and load the Codec from JSON

@staticmethod
def from_input_dataset(
Expand All @@ -89,9 +81,7 @@ def from_input_dataset(
return Codec(sorted(list(chars)))

@staticmethod
def from_texts(
texts: List[str], codec_construction_params: CodecConstructionParams
):
def from_texts(texts: List[str], codec_construction_params: CodecConstructionParams):
"""Compute a codec from given text
First computes a set of all available characters.
Expand All @@ -107,11 +97,7 @@ def from_texts(
-------
Codec based on the set of characters + whitelist
"""
chars = (
set()
if codec_construction_params.include is None
else set(codec_construction_params.include)
)
chars = set() if codec_construction_params.include is None else set(codec_construction_params.include)

for text in texts:
for c in text:
Expand Down
8 changes: 2 additions & 6 deletions calamari_ocr/ocr/dataset/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def create_pipeline(
from calamari_ocr.ocr import Codec

this_dir = os.path.dirname(os.path.realpath(__file__))
base_path = os.path.abspath(
os.path.join(this_dir, "..", "..", "test", "data", "uw3_50lines", "train")
)
base_path = os.path.abspath(os.path.join(this_dir, "..", "..", "test", "data", "uw3_50lines", "train"))

fdr = FileDataParams(
num_processes=8,
Expand All @@ -113,9 +111,7 @@ def create_pipeline(
)

params = DataParams(
codec=Codec(
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,:;-?+=_()*{}[]`@#$%^&'\""
),
codec=Codec("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,:;-?+=_()*{}[]`@#$%^&'\""),
downscale_factor=4,
line_height=48,
pre_proc=SequentialProcessorPipelineParams(
Expand Down
40 changes: 10 additions & 30 deletions calamari_ocr/ocr/dataset/datareader/abbyy/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ class Abbyy(CalamariDataGeneratorParams):
xml_files: List[str] = field(default_factory=list)
gt_extension: str = field(
default=".abbyy.xml",
metadata=pai_meta(
help="Default extension of the gt files (expected to exist in same dir)"
),
metadata=pai_meta(help="Default extension of the gt files (expected to exist in same dir)"),
)
binary: bool = False
pred_extension: str = field(
Expand All @@ -50,9 +48,7 @@ def select(self, indices: List[int]):

def to_prediction(self):
pred = deepcopy(self)
pred.xml_files = [
split_all_ext(f)[0] + self.pred_extension for f in self.xml_files
]
pred.xml_files = [split_all_ext(f)[0] + self.pred_extension for f in self.xml_files]
return pred

@staticmethod
Expand All @@ -63,9 +59,7 @@ def prepare_for_mode(self, mode: PipelineMode):
self.images = sorted(glob_all(self.images))
self.xml_files = sorted(glob_all(self.xml_files))
if not self.xml_files:
self.xml_files = [
split_all_ext(f)[0] + self.gt_extension for f in self.images
]
self.xml_files = [split_all_ext(f)[0] + self.gt_extension for f in self.images]
if not self.images:
self.xml_files = sorted(glob_all(self.xml_files))
self.images = [None] * len(self.xml_files)
Expand All @@ -79,9 +73,7 @@ def __init__(
):
super().__init__(mode, params)

self.book = XMLReader(
self.params.images, self.params.xml_files, self.params.skip_invalid
).read()
self.book = XMLReader(self.params.images, self.params.xml_files, self.params.skip_invalid).read()

for p, page in enumerate(self.book.pages):
for l, line in enumerate(page.getLines()):
Expand All @@ -90,9 +82,7 @@ def __init__(
{
"image_path": page.imgFile,
"xml_path": page.xmlFile,
"id": "{}_{}_{}_{}".format(
split_all_ext(page.xmlFile or page.imgFile)[0], p, l, f
),
"id": "{}_{}_{}_{}".format(split_all_ext(page.xmlFile or page.imgFile)[0], p, l, f),
"line": line,
"format": fo,
}
Expand All @@ -104,12 +94,8 @@ def store_text_prediction(self, sentence, sample_id, output_dir):
sample["format"].text = sentence

def store(self):
for page in tqdm(
self.book.pages, desc="Writing Abbyy files", total=len(self.book.pages)
):
XMLWriter.write(
page, split_all_ext(page.xmlFile)[0] + self.params.pred_extension
)
for page in tqdm(self.book.pages, desc="Writing Abbyy files", total=len(self.book.pages)):
XMLWriter.write(page, split_all_ext(page.xmlFile)[0] + self.params.pred_extension)

def _sample_iterator(self):
return zip(self.params.images, self.params.xml_files)
Expand All @@ -127,17 +113,13 @@ def _generate_epoch(self, text_only) -> Generator[InputSample, None, None]:
for l, line in enumerate(page.getLines()):
for f, fo in enumerate(line.formats):
fold_id += 1
sample_id = "{}_{}_{}_{}".format(
split_all_ext(page.xmlFile or page.imgFile)[0], p, l, f
)
sample_id = "{}_{}_{}_{}".format(split_all_ext(page.xmlFile or page.imgFile)[0], p, l, f)
text = None
if self.mode in TARGETS_PROCESSOR:
text = fo.text

if text_only:
yield InputSample(
None, text, SampleMeta(id=sample_id, fold_id=fold_id)
)
yield InputSample(None, text, SampleMeta(id=sample_id, fold_id=fold_id))

else:
cut_img = None
Expand All @@ -158,9 +140,7 @@ def _generate_epoch(self, text_only) -> Generator[InputSample, None, None]:
constant_values=cut_img.max(),
)

yield InputSample(
cut_img, text, SampleMeta(id=sample_id, fold_id=fold_id)
)
yield InputSample(cut_img, text, SampleMeta(id=sample_id, fold_id=fold_id))

def _load_sample(self, sample, text_only) -> Generator[InputSample, None, None]:
raise NotImplementedError
8 changes: 1 addition & 7 deletions calamari_ocr/ocr/dataset/datareader/abbyy/xml/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,13 +339,7 @@ def __init__(self, baseline: int, rect: Rect):
self.formats = []

def __str__(self):
return (
"Line:[baseline='"
+ self.baseline.__str__()
+ "', "
+ self.rect.__str__()
+ "]"
)
return "Line:[baseline='" + self.baseline.__str__() + "', " + self.rect.__str__() + "]"


class Format:
Expand Down
38 changes: 9 additions & 29 deletions calamari_ocr/ocr/dataset/datareader/abbyy/xml/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,15 @@ def read(self) -> Book:
if xmlfile:
if not os.path.exists(xmlfile):
if not self.skip_invalid:
raise XMLParseError(
"The abbyy xml file {} does not exist".format(xmlfile)
)
raise XMLParseError("The abbyy xml file {} does not exist".format(xmlfile))
else:
toremove.append(i)
continue

if imgfile:
if not os.path.exists(imgfile):
if not self.skip_invalid:
raise XMLParseError(
"The image file {} does not exist".format(imgfile)
)
raise XMLParseError("The image file {} does not exist".format(imgfile))
else:
toremove.append(i)
continue
Expand All @@ -71,9 +67,7 @@ def read(self) -> Book:
except XMLParseError as e:
logger.exception(e)
if self.skip_invalid:
logger.warning(
"Exception during XMLParsing ignored. Skipping example."
)
logger.warning("Exception during XMLParsing ignored. Skipping example.")
toremove.append(i)
continue
else:
Expand Down Expand Up @@ -106,9 +100,7 @@ def requireAttr(node, attrs):
for attr in attrs:
a[attr] = node.get(attr)
if a[attr] is None:
raise XMLParseError(
"Missing required attribute {} on node {}".format(attr, node)
)
raise XMLParseError("Missing required attribute {} on node {}".format(attr, node))

return a

Expand All @@ -118,8 +110,7 @@ def parseXMLfile(self, imgfile, xmlfile):
tree = ET.parse(xmlfile)
except ET.ParseError as e:
raise XMLParseError(
"The xml file '" + xmlfile + "' couldn't be read because of a "
"syntax error in the xml file. " + e.msg
"The xml file '" + xmlfile + "' couldn't be read because of a " "syntax error in the xml file. " + e.msg
)

root = tree.getroot()
Expand All @@ -128,9 +119,7 @@ def parseXMLfile(self, imgfile, xmlfile):
raise XMLParseError("The xml file '" + xmlfile + "' is empty.")

for pagecount, pageNode in enumerate(root):
a = XMLReader.requireAttr(
pageNode, ["width", "height", "resolution", "originalCoords"]
)
a = XMLReader.requireAttr(pageNode, ["width", "height", "resolution", "originalCoords"])
page = Page(
a["width"],
a["height"],
Expand All @@ -149,18 +138,13 @@ def parseXMLfile(self, imgfile, xmlfile):
# Reads the rectangle data and checks whether it is empty
name = blockNode.get("blockName")

block = Block(
type, name, XMLReader.parseRect(blockNode, required=False)
)
block = Block(type, name, XMLReader.parseRect(blockNode, required=False))

for textNode in blockNode:

# Again only text nodes will be considered

if (
textNode.tag
== "{http://www.abbyy.com/FineReader_xml/FineReader10-schema-v1.xml}text"
):
if textNode.tag == "{http://www.abbyy.com/FineReader_xml/FineReader10-schema-v1.xml}text":
for parNode in textNode:
align = parNode.get("align")
startIndent = parNode.get("startIndent")
Expand All @@ -178,11 +162,7 @@ def parseXMLfile(self, imgfile, xmlfile):
maxCount = 0
for formNode in lineNode:
countChars = 0
if (
formNode.text is None
or formNode.text == "\n"
or formNode.text == ""
):
if formNode.text is None or formNode.text == "\n" or formNode.text == "":
for charNode in formNode:
text += str(charNode.text)
countChars = countChars + 1
Expand Down
Loading

0 comments on commit d12c938

Please sign in to comment.