Skip to content

Commit

Permalink
"0.2.3" Optimization and bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Warrfie committed Nov 23, 2023
1 parent 23aefff commit 9d2b6dd
Show file tree
Hide file tree
Showing 12 changed files with 335 additions and 246 deletions.
3 changes: 2 additions & 1 deletion combidata/classes/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def __init__(self, case: dict, field_name: str, field_mode: str):

self.additional_fields = self.form_additional_fields(case)


def __repr__(self):
return str(vars(self))

def hand_requirements(self, requirements):
if isinstance(requirements, dict):
Expand Down
39 changes: 22 additions & 17 deletions combidata/classes/combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,24 @@
import traceback


def step_not_done(current_step_name, combi):
if isinstance(combi, list):
for combination in combi:
if combination.step_done != current_step_name and combination.step_done != "STOP":
return True
def current_workflow(workflow, is_all_steps=False):
from combidata import Process
if isinstance(workflow[0], Process):
for process in workflow:
yield process
elif is_all_steps:
for stage in workflow:
for process in stage:
yield process
else:
if combi.step_done != current_step_name and combi.step_done != "STOP":
return True
for process in workflow[0]:
yield process
workflow.pop(0)


def step_not_done(current_step_name, combi):
if combi.step_done != current_step_name and not isinstance(combi.step_done, Exception):
return True
return False


Expand All @@ -36,7 +46,7 @@ class Combination:

step_done = None # last passed step

def __init__(self, case, workflow, init_lib, template, tools, logger, generator_id):
def __init__(self, case, workflow, init_lib, template, tools, logger, generator_id, types_for_generation):
self.init_lib = copy.deepcopy(init_lib)
self.main_case = case
self.template = template
Expand All @@ -49,22 +59,20 @@ def __init__(self, case, workflow, init_lib, template, tools, logger, generator_

self.cache = {}

self.workflow = workflow
self.workflow = copy.deepcopy(workflow)

self.init_lib[self.main_case.field_name] = {self.main_case.field_mode: self.main_case}
self.types_for_generation = types_for_generation

def run(self):
self.workflow = list(self.workflow) if isinstance(self.workflow, list) else self.workflow # todo beautify
workflow = self.workflow.pop(0) if isinstance(self.workflow, list) else self.workflow
for current_step in workflow:
for current_step in current_workflow(self.workflow):
while step_not_done(current_step.name, self):
if self.step_done != current_step.name:
if self.logger:
self.logger.start_step(self.generator_id, current_step.name)
try:
current_step.activate(self)
except Exception as e:
self.step_done = "STOP"
self.step_done = e
if self.logger:
temp_exep = f"An exception occurred: {type(e).__name__}. "
temp_exep += f"Error message: {str(e)}. "
Expand All @@ -75,9 +83,6 @@ def run(self):
line_number = last_traceback.lineno
temp_exep += f"Occurred at: {file_name}:{line_number}. "
self.logger.end_step(self.generator_id, temp_exep)
else:
raise e
else:
if self.logger:
self.logger.end_step(self.generator_id)

235 changes: 87 additions & 148 deletions combidata/classes/data_generator.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,18 @@
import copy
import random
import traceback
from pprint import pprint

from combidata.classes.case import Case
from combidata.classes.combination import Combination, step_not_done


def crop_types(current_dict, poss_types):
for unit, modes in current_dict.items():
for mode in list(modes.keys()):
if current_dict[unit][mode].type_of_case not in poss_types:
del current_dict[unit][mode]


def can_combine(neutral_lib, case):
for field, modes in neutral_lib.items():
for mode in modes.values():
if (case.field_name in mode.requirements and case.field_mode in mode.requirements[
case.field_name]) or mode.field_name == case.field_name:
break
else:
return False
return True


def form_template(lib):
template = {}
for field, modes in lib.items():
template[field] = {}
for mode in modes:
template[field][mode] = copy.deepcopy(
lib[field][mode]) # TODO fix_it copy.deepcopy(lib[field][mode]) is tooo dum
for seed_field, seed_modes in lib.items():
if seed_field != field:
template[field][mode].requirements[seed_field] = set(seed_modes)
return template
from combidata.classes.combination import Combination, current_workflow
from combidata.classes.mul_dim_graph import MDG
from combidata.funcs.exeptions import CombinatoricsError


def check_all_names(init_lib):
name_set = set()
for cases in init_lib["cases"].values():
for case in cases.values():
assert case["name"] not in name_set, case["name"] + " - is not unique"
name_set.update(set(case["name"]))


def extend_dict(input_dict, final_key_count):
output_dict = {}
dict_keys = list(input_dict.keys())
key_count = len(dict_keys)

i = 0
added_key_count = 0
shuffle_point = 0 if key_count > final_key_count else (final_key_count // key_count) * key_count

while added_key_count < final_key_count:
if i % key_count == 0 and i == shuffle_point:
random.shuffle(dict_keys)
key = dict_keys[i % key_count]
value = input_dict[key]

if i < key_count:
output_dict[key] = copy.deepcopy(value)
else:
extended_key = f"{key}[{i // key_count}]"
output_dict[extended_key] = copy.deepcopy(value)

i += 1
added_key_count += 1

return output_dict
name_set.add(case["name"])


class DataGenerator:
Expand Down Expand Up @@ -111,6 +53,7 @@ def __init__(self, library: dict,
logger=None,
generator_id: str = None):

self.combinations = None
assert amount is None or (isinstance(amount, int) and amount > 0), "amount must be integer > 0"
assert banned_fields is None or isinstance(banned_fields, list), "banned_fields must be list instance"
assert possible_fields is None or isinstance(possible_fields, list), "possible_fields must be list instance"
Expand All @@ -120,80 +63,111 @@ def __init__(self, library: dict,
self.modes_for_gen = self.form_modes_for_gen(possible_modes)
self.init_lib = self.form_init_lib(library)
self.dell_fields(possible_fields, banned_fields)
self.template = library["template"]
self.template = library.get("template")
self.tools = library.get("tools")
self.logger = logger
self.generator_id = generator_id

assert (logger and generator_id) or logger is None, "You must use logger and generator_id"

self.type_of_cases = type_of_cases if type_of_cases else "standard"
self.workflow = self.get_workflow(library["workflow"], type_of_cases)

type_of_cases = type_of_cases if type_of_cases else "standard"
if types_for_generation is None:
types_for_generation = ["standard"]
if not isinstance(types_for_generation, list):
types_for_generation = [types_for_generation]
neutral_lib = self.form_neutral_lib(self.init_lib)
self.spread_requirements(neutral_lib)
crop_types(neutral_lib, types_for_generation)
self.types_for_generation = types_for_generation

self.combinations = self.find_combinations(neutral_lib, type_of_cases)
self.form_combinations()

assert self.combinations, "No combinations for tests" #TODO deep logging needed
assert self.combinations, "No combinations for tests" # TODO deep logging needed

if amount is not None:
self.combinations = extend_dict(self.combinations, amount)
def spread_requirements(self, neutral_lib):
for field, modes in neutral_lib.items():
for mode, case in modes.items():
self.init_lib[field][mode].requirements = case.requirements

def find_combinations(self, neutral_lib, type_of_cases):
all_combinations = {}
self.extend_cases(amount)

def extend_cases(self, amount):

workflow = copy.deepcopy(self.workflow)
if "ST_COMBINE" in [process.name for process in current_workflow(workflow, True)]:
combi_graph = MDG(self.init_lib, self.types_for_generation)
combinations = list(self.combinations.keys())
random.shuffle(combinations)
for combination_name in combinations:
if not combi_graph.can_combine(self.combinations[combination_name].main_case):
del self.combinations[combination_name]

dict_keys = list(self.combinations.keys())
random.shuffle(dict_keys)
key_count = len(dict_keys)

if key_count == amount:
return
elif key_count > amount:
for i in range(key_count - amount):
del self.combinations[dict_keys[i]]
return

i = 0
added_key_count = 0
shuffle_point = (amount // key_count) * key_count

while added_key_count < amount:
if i % key_count == 0 and i == shuffle_point:
random.shuffle(dict_keys)
key = dict_keys[i % key_count]
value = self.combinations[key]

if i < key_count:
self.combinations[key] = copy.deepcopy(value)
else:
extended_key = f"{key}[{i // key_count}]"
self.combinations[extended_key] = copy.deepcopy(value)

i += 1
added_key_count += 1

def form_combinations(self):
self.combinations = {}
for field_name, cases in self.init_lib.items():
for field_mode, case in cases.items():
if case.type_of_case == type_of_cases and can_combine(neutral_lib, case):
current_combination = Combination(case, self.workflow, neutral_lib,
self.template, self.tools, self.logger, self.generator_id)
all_combinations.update({case.case_name: current_combination})
return all_combinations
if case.type_of_case == self.type_of_cases:
current_combination = Combination(case, self.workflow, self.init_lib,
self.template, self.tools, self.logger, self.generator_id,
self.types_for_generation)
self.combinations.update({case.case_name: current_combination})

def run(self):
workflow = self.workflow.pop(0) if isinstance(self.workflow, list) else self.workflow
combinations = list(self.combinations.values())

for current_step in workflow:
while step_not_done(current_step.name, combinations):
for combination in combinations:
if combination.step_done != current_step.name:
if self.logger:
self.logger.start_step(self.generator_id, current_step.name)
try:
current_step.activate(combination)
except Exception as e:
combination.step_done = "STOP"
if self.logger:
temp_exep = f"An exception occurred: {type(e).__name__}. "
temp_exep += f"Error message: {str(e)}. "
traceback_list = traceback.extract_tb(e.__traceback__)
if traceback_list:
last_traceback = traceback_list[-1]
file_name = last_traceback.filename
line_number = last_traceback.lineno
temp_exep += f"Occurred at: {file_name}:{line_number}. "
self.logger.end_step(self.generator_id, temp_exep)
else:
raise e
else:
if self.logger:
self.logger.end_step(self.generator_id)

combinations_names = list(self.combinations.keys())
for combination_name in combinations_names:
self.combinations[combination_name].run()
if isinstance(self.combinations[combination_name].step_done, type(CombinatoricsError())):
del self.combinations[combination_name]

def run_one(self):
combinations = list(self.combinations.keys())
random.shuffle(combinations)
for combination_name in combinations:
combinations[combination_name].run()
if combinations[combination_name].step_done != CombinatoricsError():
return combinations[combination_name]

def get_one(self):
return self.combinations[random.choice(list(self.combinations.keys()))]
workflow = copy.deepcopy(self.workflow)
if "ST_COMBINE" in [process.name for process in current_workflow(workflow, True)]:
combi_graph = MDG(self.init_lib, self.types_for_generation)
combinations = list(self.combinations.keys())
random.shuffle(combinations)
for combination in combinations:
if combi_graph.can_combine(self.combinations[combination].main_case):
return self.combinations[combination]
else:
combinations = list(self.combinations.keys())
return self.combinations[random.choice(combinations)]

def form_modes_for_gen(self, possible_modes):
# todo def any_passed
@staticmethod
def form_modes_for_gen(possible_modes):
modes_for_gen = copy.deepcopy(possible_modes)

if modes_for_gen is not None:
Expand All @@ -213,12 +187,6 @@ def form_init_lib(self, library):
if self.modes_for_gen is not None:
if field_name in self.modes_for_gen.keys() and field_mode not in self.modes_for_gen[field_name]:
init_lib[field_name][field_mode].type_of_case = "OFF"
elif requirements := init_lib[field_name][field_mode].requirements:
for rec_field, rec_modes in requirements.items():
if rec_field in self.modes_for_gen.keys() and not rec_modes & set(
self.modes_for_gen[rec_field]):
init_lib[field_name][field_mode].type_of_case = "OFF"
break
return init_lib

def get_workflow(self, workflow, type_of_cases):
Expand All @@ -237,32 +205,3 @@ def dell_fields(self, possible_fields, banned_fields):
for field in banned_fields:
del self.init_lib[field]

def form_neutral_lib(self, init_lib):
neutral_lib = form_template(init_lib)

for field, modes in init_lib.items():
for mode in modes:
if init_lib[field][mode].requirements:
for req_unit, req_modes in init_lib[field][mode].requirements.items():
if req_unit in neutral_lib.keys() and mode in neutral_lib[field].keys():
neutral_lib[field][mode].requirements[req_unit] = req_modes & \
neutral_lib[field][mode].requirements[
req_unit]
if not neutral_lib[field][mode].requirements[req_unit]:
del neutral_lib[field][mode]
if self.logger:
self.logger.add_log(self.generator_id,
f"Mode: {mode} in field: {field}: Was deleted because will never use in generation")
modes_for_hunt = set(neutral_lib[req_unit].keys()) - req_modes
for target_mode in modes_for_hunt:
neutral_lib[req_unit][target_mode].requirements[field] = \
neutral_lib[req_unit][target_mode].requirements[field] - set(mode)
if not neutral_lib[req_unit][target_mode].requirements[field]:
del neutral_lib[req_unit][target_mode]
if self.logger:
self.logger.add_log(self.generator_id,
f"Mode: {target_mode} in field: {req_unit}: Was deleted because will never use in generation")



return neutral_lib
Loading

0 comments on commit 9d2b6dd

Please sign in to comment.