diff --git a/Docs b/Docs index 80c6d8bc..0187712c 160000 --- a/Docs +++ b/Docs @@ -1 +1 @@ -Subproject commit 80c6d8bce1c33ab64caf754bf4309cceebc15b3e +Subproject commit 0187712c053ab9419defe6c4fac49a8b069a365e diff --git a/Tasks b/Tasks index e5bb5c4f..0187712c 160000 --- a/Tasks +++ b/Tasks @@ -1 +1 @@ -Subproject commit e5bb5c4fbb828678f5d1ec6cdec6168c8c2c6751 +Subproject commit 0187712c053ab9419defe6c4fac49a8b069a365e diff --git a/domiknows/graph/dataNode.py b/domiknows/graph/dataNode.py index b84c77c1..04b45077 100644 --- a/domiknows/graph/dataNode.py +++ b/domiknows/graph/dataNode.py @@ -4,15 +4,16 @@ import re from itertools import count -from .dataNodeConfig import dnConfig +from .dataNodeConfig import dnConfig -from ordered_set import OrderedSet +from ordered_set import OrderedSet from domiknows import getRegrTimer_logger, getProductionModeStatus from domiknows.solver import ilpOntSolverFactory from domiknows.utils import getDnSkeletonMode from domiknows.graph.relation import Contains + import logging from logging.handlers import RotatingFileHandler from .property import Property @@ -42,7 +43,7 @@ logBackupCount = dnConfig['log_backupCount'] if 'log_fileMode' in dnConfig: logFileMode = dnConfig['log_fileMode'] - + # Create file handler and set level to info import pathlib pathlib.Path("logs").mkdir(parents=True, exist_ok=True) @@ -61,7 +62,7 @@ _DataNode__Logger.addHandler(ch) # Don't propagate _DataNode__Logger.propagate = False - + # --- Create loggers _DataNodeBuilder__Logger = logging.getLogger("dataNodeBuilder") _DataNodeBuilder__Logger.setLevel(logLevel) @@ -75,7 +76,7 @@ class DataNode: """ Represents a single data instance in a graph with relation links to other data nodes. - + Attributes: - myBuilder (DatanodeBuilder): DatanodeBuilder used to construct this datanode. - instanceID (various): The data instance ID (e.g., paragraph number, sentence number). @@ -90,7 +91,7 @@ class DataNode: - myLoggerTime (Logger): Logger for time measurement. """ _ids = count(1) - + def __init__(self, myBuilder = None, instanceID = None, instanceValue = None, ontologyNode = None, graph = None, relationLinks = {}, attributes = {}): """Initialize a DataNode instance. @@ -102,7 +103,7 @@ def __init__(self, myBuilder = None, instanceID = None, instanceValue = None, on graph (Graph): Graph to which the DataNode belongs. relationLinks (dict): Dictionary mapping relation name to RelationLinks. attributes (dict): Dictionary with node's attributes. - + Attributes: myBuilder (DatanodeBuilder): DatanodeBuilder used to construct this datanode. instanceID (various): The data instance ID. @@ -121,42 +122,42 @@ def __init__(self, myBuilder = None, instanceID = None, instanceValue = None, on self.instanceID = instanceID # The data instance id (e.g. paragraph number, sentence number, phrase number, image number, etc.) self.instanceValue = instanceValue # Optional value of the instance (e.g. paragraph text, sentence text, phrase text, image bitmap, etc.) self.ontologyNode = ontologyNode # Reference to the node in the ontology graph (e.g. Concept) which is the type of this instance (e.g. paragraph, sentence, phrase, etc.) 
- + if ontologyNode is not None: self.graph = self.ontologyNode.sup if graph is not None: self.graph = graph - + if relationLinks: self.relationLinks = relationLinks # Dictionary mapping relation name to RelationLinks else: self.relationLinks = {} - + self.impactLinks = {} # Dictionary with dataNodes impacting this dataNode by having it as a subject of its relation - + if attributes: self.attributes = attributes # Dictionary with node's attributes else: self.attributes = {} - + self.current_device = 'cpu' if torch.cuda.is_available(): if torch.cuda.device_count() > 1: self.current_device = 'cuda:1' else: self.current_device = 'cuda' if torch.cuda.is_available() else 'cpu' - + self.gurobiModel = None - + self.myLoggerTime = getRegrTimer_logger() - + class DataNodeError(Exception): """Exception raised for DataNode-related errors.""" pass - + def __str__(self): """Return the string representation of the DataNode object. - + Returns: str: String representation of the instance. """ @@ -164,10 +165,10 @@ def __str__(self): return self.instanceValue else: return '{} {}'.format(self.ontologyNode.name, self.instanceID) - + def __repr__(self): """Return the unambiguous string representation of the DataNode object. - + Returns: str: Unambiguous string representation of the instance. """ @@ -175,60 +176,60 @@ def __repr__(self): return self.instanceValue else: return '{} {}'.format(self.ontologyNode.name, self.instanceID) - + def __reprDeep__(self, strRep=""): """Return the deep string representation of the DataNode object including its relations. - + Args: strRep (str): Accumulated string representation. - + Returns: str: Deep string representation of the instance. """ rel = [*self.getRelationLinks().keys()] if 'contains' in rel: rel.remove('contains') - + relString = None if len(rel) > 0: relString = ' (' + str(rel) + ')' - + if relString: strRep += self.ontologyNode.name + str(rel) else: - strRep += self.ontologyNode.name - + strRep += self.ontologyNode.name + childrenDns = {} for cDn in self.getChildDataNodes(): if cDn.getOntologyNode().name not in childrenDns: childrenDns[cDn.getOntologyNode().name] = [] childrenDns[cDn.getOntologyNode().name].append(cDn) - + strRep += '\n' for childType in childrenDns: strRep += '\t' + childrenDns[childType][0].__repr__(strRep) - + return strRep - + def getInstanceID(self): """Get the instance ID of the DataNode object. - + Returns: various: Instance ID of the DataNode object. """ return self.instanceID - + def getInstanceValue(self): """Get the instance value of the DataNode object. - + Returns: various: Instance value of the DataNode object. """ return self.instanceValue - + def getOntologyNode(self): """Get the ontology node related to the DataNode object. - + Returns: Node: Ontology node related to the DataNode object. """ @@ -237,18 +238,18 @@ def getOntologyNode(self): def visualize(self, filename: str, inference_mode="ILP", include_legend=False, open_image=False): """Visualize the current DataNode instance and its attributes. - This method creates a graph visualization using the Graphviz library. The + This method creates a graph visualization using the Graphviz library. The visualization includes attributes and relationships. - + Args: filename (str): The name of the file where the Graphviz output will be stored. inference_mode (str, optional): The mode used for inference ("ILP" by default). include_legend (bool, optional): Whether or not to include a legend in the visualization. 
open_image (bool, optional): Whether or not to automatically open the generated image. - + Raises: Exception: If the specified inference_mode is not found in the DataNode. - + """ if include_legend: # Build Legend subgraph @@ -303,7 +304,7 @@ def visualize(self, filename: str, inference_mode="ILP", include_legend=False, o else: # Normal nodes g.attr('node', shape='rectangle') - + # Format attribute attr_str = str(attribute) if isinstance(attribute, torch.Tensor): @@ -318,59 +319,59 @@ def visualize(self, filename: str, inference_mode="ILP", include_legend=False, o main_graph.subgraph(g) main_graph.render(filename, format='png', view=open_image) - + # --- Attributes methods - + def getAttributes(self): """Get all attributes of the DataNode. - + Returns: dict: Dictionary containing all attributes of the DataNode. """ return self.attributes - + def hasAttribute(self, key): """Check if the DataNode has a specific attribute. - + Args: key (str): The key of the attribute to check for. - + Returns: bool: True if the attribute exists, False otherwise. """ # Your code for checking attribute existence # ... return False - + def getAttribute(self, *keys): """Retrieve a specific attribute using a key or a sequence of keys. - - The method accepts multiple keys in the form of positional arguments, - combining them to identify the attribute to retrieve. - + + The method accepts multiple keys in the form of positional arguments, + combining them to identify the attribute to retrieve. + Args: *keys (str or tuple or Concept): The key(s) to identify the attribute. - + Returns: object: The value of the attribute if it exists, or None otherwise. """ key = "" keyBis = "" index = None - + conceptFound = False for _, kConcept in enumerate(keys): if key != "": key = key + "/" keyBis = keyBis + "/" - + # Handle different way of representing concept in the key list if isinstance(kConcept, str): # Concept name conceptForK = None if not conceptFound: conceptForK = self.findConcept(kConcept) # Find concept - - if not conceptFound and conceptForK is not None: + + if not conceptFound and conceptForK is not None: conceptFound = True if isinstance(conceptForK, tuple): key = key + '<' + conceptForK[0].name +'>' @@ -390,8 +391,8 @@ def getAttribute(self, *keys): conceptFound = True key = key + '<' + kConcept.name +'>' keyBis = keyBis + kConcept.name - - # Use key and keyBis to get the dn attribute + + # Use key and keyBis to get the dn attribute if key in self.attributes: if index is None: return self.attributes[key] @@ -406,7 +407,7 @@ def getAttribute(self, *keys): if "rootDataNode" in self.attributes: rootDataNode = self.attributes["rootDataNode"] keyInVariableSet = self.ontologyNode.name + "/" + key - + if "variableSet" in rootDataNode.attributes: if keyInVariableSet in rootDataNode.attributes["variableSet"]: return rootDataNode.attributes["variableSet"][keyInVariableSet][self.instanceID] @@ -417,11 +418,11 @@ def getAttribute(self, *keys): return self.attributes["variableSet"][key] elif "propertySet" in self.attributes and key in self.attributes["propertySet"]: return self.attributes["propertySet"][key] - - return None - + + return None + # --- Relation Link methods - + def getRelationLinks(self, relationName = None, conceptName = None): """Retrieve relation links for a given relation and concept name. @@ -429,11 +430,11 @@ def getRelationLinks(self, relationName = None, conceptName = None): the concept name. It supports the flexibility to look up based on either just a relation name, just a concept name, or both. 
If neither is given, it returns all relation links. - + Args: relationName (str or None): The name of the relation to filter by. If None, no filtering is done based on the relation name. conceptName (str or None): The name of the concept to filter by. If None, no filtering is done based on the concept name. - + Returns: list: A list of DataNodes that match the given relation and concept names, or an empty list if no matches are found. """ @@ -442,176 +443,176 @@ def getRelationLinks(self, relationName = None, conceptName = None): return self.relationLinks else: conceptCN = [] - + for r in self.relationLinks: for dn in self.relationLinks[r]: if dn.ontologyNode.name == conceptName: conceptCN.append(dn) - + return conceptCN - + if not isinstance(relationName, str): relationName = relationName.name - + if relationName in self.relationLinks: relDNs = self.relationLinks[relationName] - + if conceptName is None: return relDNs else: conceptCN = [] - + if not isinstance(conceptName, str): conceptName = conceptName.name - + for dn in relDNs: if dn.ontologyNode.name == conceptName: conceptCN.append(dn) - + return conceptCN else: return [] - + def addRelationLink(self, relationName, dn): """Add a relation link between the current DataNode and another DataNode. - + This method establishes a relation link from the current DataNode to another DataNode ('dn') under a given relation name. It also updates the impactLinks for the target DataNode. - + Args: relationName (str): The name of the relation to add. dn (DataNode): The target DataNode to link to. - + Returns: None """ if relationName is None: return - + if relationName not in self.relationLinks: self.relationLinks[relationName] = [] - + if dn in self.relationLinks[relationName]: - return - + return + self.relationLinks[relationName].append(dn) - + # Impact if relationName not in dn.impactLinks: dn.impactLinks[relationName] = [] - + if self not in dn.impactLinks[relationName]: dn.impactLinks[relationName].append(self) def removeRelationLink(self, relationName, dn): """Remove a relation link between the current DataNode and another DataNode. - - This method removes a relation link from the current DataNode to another + + This method removes a relation link from the current DataNode to another DataNode ('dn') under a given relation name. It also updates the impactLinks for the target DataNode. - + Args: relationName (str): The name of the relation to remove. dn (DataNode): The target DataNode to unlink from. - + Returns: None """ if relationName is None: return - + if relationName not in self.relationLinks: return - + self.relationLinks[relationName].remove(dn) - + # Impact if relationName in dn.impactLinks: dn.impactLinks[relationName].remove(self) - + def getLinks(self, relationName = None, conceptName = None): """Get links associated with the DataNode based on the relation and concept names. - + This method retrieves the DataNodes linked to the current DataNode through - either relation links or impact links. You can filter these links based on + either relation links or impact links. You can filter these links based on the name of the relation or the name of the concept (ontology node). - + Args: - relationName (str, optional): The name of the relation to filter by. + relationName (str, optional): The name of the relation to filter by. Defaults to None. conceptName (str, optional): The name of the ontology node (concept) to filter by. Defaults to None. 
- + Returns: - dict or list: A dictionary containing the DataNodes linked through relation or - impact links. If relationName or conceptName is provided, + dict or list: A dictionary containing the DataNodes linked through relation or + impact links. If relationName or conceptName is provided, returns a list of DataNodes that match the criteria. """ keys = self.relationLinks.keys() | self.impactLinks.keys() - + links = {} for k in keys: if k not in self.relationLinks: links[k] = self.impactLinks[k] continue - + if k not in self.impactLinks: links[k] = self.relationLinks[k] continue - + links[k] = self.impactLinks[k] + self.relationLinks[k] - + if relationName is None: if conceptName is None: return links else: conceptCN = [] - + for r in links: for dn in links[r]: if dn.ontologyNode.name == conceptName: conceptCN.append(dn) - + return conceptCN - + if not isinstance(relationName, str): relationName = relationName.name - + if relationName in links: relDNs = links[relationName] - + if conceptName is None: return relDNs else: conceptCN = [] - + if not isinstance(conceptName, str): conceptName = conceptName.name - + for dn in relDNs: if dn.ontologyNode.name == conceptName: conceptCN.append(dn) - + return conceptCN else: return [] # --- Contains (children) relation methods - + def getChildDataNodes(self, conceptName = None): """Retrieve child DataNodes based on the concept name. - + Args: conceptName (str, optional): The name of the concept to filter the child DataNodes. Defaults to None. - + Returns: list: A list of child DataNodes that match the given concept name. Returns None if there are no child DataNodes. """ containsDNs = self.getRelationLinks('contains') - + if conceptName is None: return containsDNs @@ -619,7 +620,7 @@ def getChildDataNodes(self, conceptName = None): return None conceptCN = [] - + for dn in containsDNs: if isinstance(conceptName, str): if dn.ontologyNode.name == conceptName: @@ -627,76 +628,76 @@ def getChildDataNodes(self, conceptName = None): else: if dn.ontologyNode == conceptName: conceptCN.append(dn) - + return conceptCN - + def addChildDataNode(self, dn): """Add a child DataNode to the current DataNode. - + Args: dn (DataNode): The DataNode to be added as a child. """ relationName = 'contains' - + if (relationName in self.relationLinks) and (dn in self.relationLinks[relationName]): return - + self.addRelationLink(relationName, dn) def removeChildDataNode(self, dn): """Remove a child DataNode from the current DataNode. - + Args: dn (DataNode): The DataNode to be removed. """ relationName = 'contains' self.removeRelationLink(relationName, dn) - + def resetChildDataNode(self): """Reset all child DataNodes from the current DataNode. """ relationName = 'contains' self.relationLinks[relationName] = [] - + # --- Equality methods - + def getEqualTo(self, equalName="equalTo", conceptName=None): """Retrieve DataNodes that are equal to the current DataNode. - + Args: equalName (str, optional): The name of the relation for equality. Defaults to "equalTo". conceptName (str, optional): The name of the concept to filter the DataNodes. Defaults to None. - + Returns: list: A list of DataNodes that are considered equal to the current DataNode. 
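        Example:
            A minimal usage sketch (illustrative only; it assumes `dn` is a populated
            DataNode and "entity" is a hypothetical concept name):

                equal_dns = dn.getEqualTo(conceptName="entity")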
""" if conceptName: dns = self.getRelationLinks(relationName=equalName) - + filteredDns = [] for dn in dns: if dn.getOntologyNode().name == conceptName: filteredDns.append(dn) - + return filteredDns else: return self.getRelationLinks(relationName=equalName) - + def addEqualTo(self, equalDn, equalName="equalTo"): """Add a DataNode that is considered equal to the current DataNode. - + Args: equalDn (DataNode): The DataNode to be added. equalName (str, optional): The name of the relation for equality. Defaults to "equalTo". """ self.addRelationLink(equalName, equalDn) - + def removeEqualTo(self, equalDn, equalName="equalTo"): """Remove a DataNode that is considered equal to the current DataNode. - + Args: equalDn (DataNode): The DataNode to be removed. equalName (str, optional): The name of the relation for equality. Defaults to "equalTo". @@ -705,21 +706,21 @@ def removeEqualTo(self, equalDn, equalName="equalTo"): # --- Query methods - + def findConceptsAndRelations(self, dn, conceptsAndRelations = None, visitedDns = None): """Recursively search for concepts and relations in the data graph starting from a given dataNode (dn). - This method will traverse through linked dataNodes to find concepts and relations. If 'variableSet' + This method will traverse through linked dataNodes to find concepts and relations. If 'variableSet' is present in the attributes, it will return those concepts directly. - + Args: dn (DataNode): The dataNode from which to start the search. conceptsAndRelations (set, optional): A set to store found concepts and relations. Defaults to None. visitedDns (set, optional): A set to keep track of visited dataNodes to prevent cycles. Defaults to None. - + Returns: set: A set containing the names of all found concepts and relations. - + """ if 'variableSet' in self.attributes: conceptsAndRelations = set() @@ -727,32 +728,32 @@ def findConceptsAndRelations(self, dn, conceptsAndRelations = None, visitedDns = if "/label" in key: continue conceptsAndRelations.add(key[key.index('<')+1:key.index('>')]) - + return conceptsAndRelations - else: + else: if conceptsAndRelations is None: conceptsAndRelations = set() if visitedDns is None: visitedDns = set() - + # Find concepts in dataNode - concept are in attributes from learning sensors for att in dn.attributes: - if att[0] == '<' and att[-1] == '>': + if att[0] == '<' and att[-1] == '>': if att[1:-1] not in conceptsAndRelations: conceptsAndRelations.add(att[1:-1]) _DataNode__Logger.info('Found concept %s in dataNode %s'%(att[1:-1],dn)) - - # Recursively find concepts and relations in linked dataNodes + + # Recursively find concepts and relations in linked dataNodes links = dn.getLinks() if links: for link in links: for lDn in links[link]: if lDn in visitedDns: continue - + visitedDns.add(lDn) self.findConceptsAndRelations(lDn, conceptsAndRelations = conceptsAndRelations, visitedDns = visitedDns) - + return conceptsAndRelations def findConceptsNamesInDatanodes(self, dns = None, conceptNames = None, relationNames = None): @@ -762,7 +763,7 @@ def findConceptsNamesInDatanodes(self, dns = None, conceptNames = None, relation dns (list, optional): A list of DataNodes to be searched. Defaults to None. conceptNames (set, optional): A set to store the names of concepts found. Defaults to None. relationNames (set, optional): A set to store the names of relations found. Defaults to None. - + Returns: tuple: A tuple containing two sets: (conceptNames, relationNames). 
""" @@ -773,55 +774,55 @@ def findConceptsNamesInDatanodes(self, dns = None, conceptNames = None, relation if dns is None: dns = [self] - + for dn in dns: conceptNames.add(dn.getOntologyNode().name) for relName, _ in dn.getRelationLinks().items(): if relName != 'contains': relationNames.add(relName) - + self.findConceptsNamesInDatanodes(dns=dn.getChildDataNodes(), conceptNames = conceptNames, relationNames = relationNames) - + return conceptNames, relationNames - + def findRootConceptOrRelation(self, relationConcept, usedGraph = None): """Finds the root concept or relation of a given relation or concept. Args: relationConcept (str or Object): The relation or concept to find the root for. usedGraph (Object, optional): The ontology graph where the relation or concept exists. Defaults to None. - + Returns: Object or str: The root concept or relation. """ if usedGraph is None: usedGraph = self.ontologyNode.getOntologyGraph() - + if isinstance(relationConcept, str): _relationConcepts = self.findConcept(relationConcept) - + if _relationConcepts: relationConcept = _relationConcepts[0] else: - return relationConcept + return relationConcept # Does this concept or relation has parent (through _isA) - + if isinstance(relationConcept, tuple): relationConcept = relationConcept[0] - + try: isAs = relationConcept.is_a() except (AttributeError, TypeError): isAs = [] - + for _isA in isAs: _relationConcept = _isA.dst - + return self.findRootConceptOrRelation(_relationConcept, usedGraph) - + # If the provided concept or relation is root (has not parents) - return relationConcept + return relationConcept def __testDataNode(self, dn, test): """Tests a DataNode based on various types of conditions. @@ -829,40 +830,40 @@ def __testDataNode(self, dn, test): Args: dn (Object): The DataNode to be tested. test (tuple/list/str/int): The conditions to test the DataNode. - + Returns: bool: True if the DataNode satisfies the conditions, False otherwise. """ if test is None: return False - + if isinstance(test, tuple) or isinstance(test, list): # tuple with at least three elements (concept, key elements, value of attribute) _test = [] for t in test: if isinstance(t, tuple): r = self.__testDataNode(dn, t) - + if not r: return False else: _test.append(t) - + if len(_test) == 0: return True else: test = _test - - if len(test) >= 3: + + if len(test) >= 3: if isinstance(test[0], str): if dn.getOntologyNode().name != test[0]: return False else: if dn.getOntologyNode().name != test[0].name: return False - + keys = test[1:-1] v = dn.getAttribute(*keys) - + last = test[-1] if v == last: return True @@ -870,47 +871,48 @@ def __testDataNode(self, dn, test): return False else: test = [test] - + for i, t in enumerate(test): if isinstance(t, int): if dn.getInstanceID() == t: return True else: return False - + if t == "instanceID" and i < len(test) - 1: if dn.getInstanceID() == test[i+1]: return True else: return False - + if not isinstance(t, str): t = t.name - + if t == dn.getOntologyNode().name: return True else: return False - + def getDnsForRelation(self, rel): """Get DataNodes associated with a given relation. The method first finds the root concept or relation for the given 'rel'. Depending on what it finds, it returns the corresponding DataNodes. - + Args: rel (str/Object): The relation or concept for which DataNodes are needed. - + Returns: list: A list of DataNodes corresponding to the relation, or [None] if not found. 
""" relRoot = self.findRootConceptOrRelation(rel) - + if relRoot is None: return [None] if isinstance(relRoot, Contains): + relRoot = "contains" #TODO fix the generic name for multiple constaints if not isinstance(relRoot, str): @@ -918,7 +920,7 @@ def getDnsForRelation(self, rel): if relRoot.endswith(".reversed"): relRoot = relRoot[:-len(".reversed")] - if relRoot in self.impactLinks: + if relRoot in self.impactLinks: return self.impactLinks[relRoot] else: return [None] @@ -926,7 +928,7 @@ def getDnsForRelation(self, rel): return self.relationLinks[relRoot] else: return [None] - + def findDatanodes(self, dns = None, select = None, indexes = None, visitedDns = None, depth = 0): """Find and return DataNodes based on the given query conditions. @@ -936,57 +938,57 @@ def findDatanodes(self, dns = None, select = None, indexes = None, visitedDns = indexes (dict): Optional query filtering. visitedDns (OrderedSet): Keeps track of already visited DataNodes. depth (int): Depth of the recursive call. - + Returns: list: List of DataNodes that satisfy the query condition. """ # If no DataNodes provided use self if not depth and dns is None: dns = [self] - + returnDns = [] - + # If empty list of provided DataNodes then return - it is a recursive call with empty list if dns is None or len(dns) == 0: return returnDns - + # No select provided - query not defined - return if select is None: if depth == 0 and not returnDns: _DataNode__Logger.warning('Not found any DataNode - no value for the select part of query provided') - + return returnDns - - # Check each provided DataNode if it satisfy the select part of the query + + # Check each provided DataNode if it satisfy the select part of the query for dn in dns: # Test current DataNote against the query if self.__testDataNode(dn, select): if dn not in returnDns: - returnDns.append(dn) - + returnDns.append(dn) + if not visitedDns: visitedDns = OrderedSet() - + visitedDns.add(dn) - + # Call recursively newDepth = depth + 1 for dn in dns: # Visit DataNodes in links - for r, rValue in dn.getLinks().items(): - + for r, rValue in dn.getLinks().items(): + # Check if the nodes already visited dnsToVisit = OrderedSet() for rDn in rValue: if rDn not in visitedDns: dnsToVisit.add(rDn) - + if not dnsToVisit: continue - + # Visit DataNodes in the current relation currentRelationDns = self.findDatanodes(dnsToVisit, select = select, indexes = indexes, visitedDns = visitedDns, depth = newDepth) - + if currentRelationDns is not None: for currentRDn in currentRelationDns: if currentRDn not in returnDns: @@ -994,21 +996,21 @@ def findDatanodes(self, dns = None, select = None, indexes = None, visitedDns = if depth: # Finish recursion return returnDns - + # If index provided in query then filter the found results for the select part of query through the index part of query if (indexes != None): currentReturnDns = [] # Will contain results from returnDns satisfying the index - + for dn in returnDns: - fit = True + fit = True for indexName, indexValue in indexes.items(): - + relDns = dn.getDnsForRelation(indexName) - + if relDns is None or len(relDns) == 0 or relDns[0] is None: fit = False break - + found = False for rDn in relDns: if isinstance(indexValue, tuple): @@ -1016,48 +1018,48 @@ def findDatanodes(self, dns = None, select = None, indexes = None, visitedDns = for t in indexValue: if isinstance(t, tuple): r = self.__testDataNode(rDn, t) - + if r: found = True break else: _test.append(t) - + if len(_test) == 0: continue else: indexValue = _test - + if 
self.__testDataNode(rDn, indexValue): found = True break - + if not found: fit = False break - + if fit: if dn not in currentReturnDns: currentReturnDns.append(dn) - + returnDns = currentReturnDns - + # If not fund any results if depth == 0 and not returnDns: _DataNode__Logger.debug('Not found any DataNode for - %s -'%(select)) - + # Sort results according to their ids if returnDns: returnDnsNotSorted = OrderedDict() for dn in returnDns: returnDnsNotSorted[dn.getInstanceID()] = dn - + returnDnsSorted = OrderedDict(sorted(returnDnsNotSorted.items())) - + returnDns = [*returnDnsSorted.values()] - + return returnDns - + # Get root of the dataNode def getRootDataNode(self): """Get the root DataNode. @@ -1069,59 +1071,59 @@ def getRootDataNode(self): return self.impactLinks["contains"][0].getRootDataNode() else: return self - + # Keeps hashMap of concept name queries in findConcept to results conceptsMap = {} - + def findConcept(self, conceptName, usedGraph = None): """Find concept based on the name in the ontology graph. Args: conceptName (str or Concept): The name of the concept to find. usedGraph (object): The ontology graph to search within. - + Returns: tuple or None: A tuple containing details about the found concept or None if not found. """ if '<' in conceptName: conceptName = conceptName[1:-1] - + if not usedGraph: usedGraph = self.ontologyNode.getOntologyGraph() - + if usedGraph not in self.conceptsMap: self.conceptsMap[usedGraph] = {} - + usedGraphConceptsMap = self.conceptsMap[usedGraph] - + if isinstance(conceptName, Concept): conceptName = conceptName.name() - + if conceptName in usedGraphConceptsMap: return usedGraphConceptsMap[conceptName] - + subGraph_keys = [key for key in usedGraph._objs] for subGraphKey in subGraph_keys: subGraph = usedGraph._objs[subGraphKey] - + for conceptNameItem in subGraph.concepts: if conceptName == conceptNameItem: concept = subGraph.concepts[conceptNameItem] - + usedGraphConceptsMap[conceptName] = (concept, concept.name, None, 1) return usedGraphConceptsMap[conceptName] - + elif isinstance(subGraph.concepts[conceptNameItem], EnumConcept): vlen = len(subGraph.concepts[conceptNameItem].enum) - + if conceptName in subGraph.concepts[conceptNameItem].enum: concept = subGraph.concepts[conceptNameItem] - + usedGraphConceptsMap[conceptName] = (concept, conceptName, subGraph.concepts[conceptNameItem].get_index(conceptName), vlen) return usedGraphConceptsMap[conceptName] usedGraphConceptsMap[conceptName] = None - + return usedGraphConceptsMap[conceptName] def isRelation(self, conceptRelation, usedGraph = None): @@ -1130,134 +1132,134 @@ def isRelation(self, conceptRelation, usedGraph = None): Args: conceptRelation (str or Concept): The concept or relation to check. usedGraph (object, optional): The ontology graph to use. Defaults to the one associated with self. - + Returns: bool: True if the concept is a relation, otherwise False. 
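        Example:
            A minimal usage sketch (illustrative; "work_for" is a hypothetical concept name):

                is_relation = dn.isRelation("work_for")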
""" if usedGraph is None: usedGraph = self.ontologyNode.getOntologyGraph() - + if isinstance(conceptRelation, str): conceptRelation = self.findConcept(conceptRelation) - + if conceptRelation == None: return False - + conceptRelation = conceptRelation[0] - + from domiknows.graph.relation import Relation if isinstance(conceptRelation, Relation): return True - - if len(conceptRelation.has_a()) > 0: + + if len(conceptRelation.has_a()) > 0: return True - + for _isA in conceptRelation.is_a(): _conceptRelation = _isA.dst - + if self.__isRelation(_conceptRelation, usedGraph): return True - - return False - + + return False + def getRelationAttrNames(self, conceptRelation, usedGraph = None): """Get attribute names for a given relation or concept that is a relation. Args: conceptRelation (Concept): The concept or relation to check for attributes. usedGraph (object, optional): The ontology graph to use. Defaults to the ontology graph associated with self. - + Returns: OrderedDict or None: An ordered dictionary of attribute names and their corresponding concepts, or None if no attributes found. """ if usedGraph is None: usedGraph = self.ontologyNode.getOntologyGraph() - - if len(conceptRelation.has_a()) > 0: + + if len(conceptRelation.has_a()) > 0: relationAttrs = OrderedDict() - for _, rel in enumerate(conceptRelation.has_a()): - dstName = rel.dst.name + for _, rel in enumerate(conceptRelation.has_a()): + dstName = rel.dst.name relationAttr = self.findConcept(dstName, usedGraph)[0] - + relationAttrs[rel.name] = relationAttr - + return relationAttrs - + for _isA in conceptRelation.is_a(): _conceptRelation = _isA.dst - + resultForCurrent = self.__getRelationAttrNames(_conceptRelation, usedGraph) - + if bool(resultForCurrent): return resultForCurrent - - return None + + return None # cache collectedConceptsAndRelations = None - + def collectConceptsAndRelations(self, conceptsAndRelations = None): """Collect all the concepts and relations from the data graph and transform them into tuple form. Args: conceptsAndRelations (set, optional): A set to store the found concepts and relations. Defaults to None. - + Returns: list: A list of tuples, each representing a concept or relation with additional information. 
""" if conceptsAndRelations is None: conceptsAndRelations = set() - + if self.collectedConceptsAndRelations: return self.collectedConceptsAndRelations - + # Search the graph starting from self for concepts and relations - candR = self.findConceptsAndRelations(self) + candR = self.findConceptsAndRelations(self) self.rootConcepts = [] - + returnCandR = [] - + # Process founded concepts - translate them to tuple form with more information needed for logical constraints and metrics for c in candR: _concept = self.findConcept(c)[0] - + if _concept is None: continue - + if isinstance(_concept, tuple): _concept = _concept[0] - + # Check if this is multiclass concept if isinstance(_concept, EnumConcept): self.rootConcepts.append((_concept, len(_concept.enum))) - + for i, a in enumerate(_concept.enum): - + if conceptsAndRelations and a not in conceptsAndRelations: # continue pass - + returnCandR.append((_concept, a, i, len(_concept.enum))) # Create tuple representation for multiclass concept else: self.rootConcepts.append((_concept, 1)) if conceptsAndRelations and c not in conceptsAndRelations and _concept not in conceptsAndRelations: continue - + returnCandR.append((_concept, _concept.name, None, 1)) # Create tuple representation for binary concept - + self.collectedConceptsAndRelations = returnCandR return self.collectedConceptsAndRelations - + def getILPSolver(self, conceptsRelations = None): """Get the ILP Solver instance based on the given concepts and relations. Args: conceptsRelations (list, optional): A list of concepts and relations to be considered. Defaults to None. - + Returns: tuple: An instance of ILP Solver and the list of processed concepts and relations. - + Raises: DataNodeError: If the ILP Solver is not initialized. """ @@ -1265,32 +1267,32 @@ def getILPSolver(self, conceptsRelations = None): conceptsRelations = [] _conceptsRelations = [] - + # Get ontology graphs and then ilpOntsolver myOntologyGraphs = {self.ontologyNode.getOntologyGraph()} - + for currentConceptOrRelation in conceptsRelations: if isinstance(currentConceptOrRelation, str): currentConceptOrRelation = self.findConcept(currentConceptOrRelation) - + _conceptsRelations.append(currentConceptOrRelation) - + if isinstance(currentConceptOrRelation, tuple): currentOntologyGraph = currentConceptOrRelation[0].getOntologyGraph() else: currentOntologyGraph = currentConceptOrRelation.getOntologyGraph() - + if currentOntologyGraph is not None: myOntologyGraphs.add(currentOntologyGraph) - + myilpOntSolver = ilpOntSolverFactory.getOntSolverInstance(myOntologyGraphs) - + if not myilpOntSolver: _DataNode__Logger.error("ILPSolver not initialized") raise DataNode.DataNodeError("ILPSolver not initialized") - + return myilpOntSolver, _conceptsRelations - + #----------------- Solver methods def collectInferredResults(self, concept, inferKey): @@ -1299,49 +1301,49 @@ def collectInferredResults(self, concept, inferKey): Args: concept (Concept or tuple): The concept for which to collect inferred results. inferKey (str): The type of inference, e.g., 'ILP', 'softmax', 'argmax'. - + Returns: torch.Tensor: Tensor containing collected attribute list. 
""" collectAttributeList = [] - + if not isinstance(concept, tuple): if not isinstance(concept, Concept): concept = self.findConcept(concept) if concept is None: return torch.tensor(collectAttributeList) - + if isinstance(concept, EnumConcept): concept = (concept, concept.name, None, len(concept.enum)) else: concept = (concept, concept.name, None, 1) - + rootConcept = self.findRootConceptOrRelation(concept[0]) - + if not rootConcept: return torch.tensor(collectAttributeList) - + rootConceptDns = self.findDatanodes(select = rootConcept) - + if not rootConceptDns: return torch.tensor(collectAttributeList) - + if getDnSkeletonMode() and "variableSet" in self.attributes: vKeyInVariableSet = rootConcept.name + "/<" + concept[0].name +">" - + # inferKey inferKeyInVariableSet = vKeyInVariableSet + "/" + inferKey - + if self.hasAttribute(inferKeyInVariableSet): return self.getAttribute(inferKeyInVariableSet) keys = [concept, inferKey] - + for dn in rootConceptDns: rTensor = dn.getAttribute(*keys) if rTensor is None: continue - + if torch.is_tensor(rTensor): if len(rTensor.shape) == 0 or len(rTensor.shape) == 1 and rTensor.shape[0] == 1: collectAttributeList.append(rTensor.item()) @@ -1357,43 +1359,43 @@ def collectInferredResults(self, concept, inferKey): collectAttributeList.append(1) else: collectAttributeList.append(0) - + if collectAttributeList and torch.is_tensor(collectAttributeList[0]): return torch.stack(tuple(collectAttributeList), dim=0) - - return torch.as_tensor(collectAttributeList) - + + return torch.as_tensor(collectAttributeList) + def infer(self): """Calculate argMax and softMax for the ontology-based data structure.""" - conceptsRelations = self.collectConceptsAndRelations() - + conceptsRelations = self.collectConceptsAndRelations() + for c in conceptsRelations: cRoot = self.findRootConceptOrRelation(c[0]) - + # ----- skeleton - tensor if getDnSkeletonMode() and "variableSet" in self.attributes: vKeyInVariableSet = cRoot.name + "/<" + c[0].name +">" - + # softmax softmaxKeyInVariableSet = vKeyInVariableSet + "/softmax" - + if not self.hasAttribute(softmaxKeyInVariableSet): vKeyInVariableSetValues = self.attributes["variableSet"][vKeyInVariableSet] if c[2] is not None: v = vKeyInVariableSetValues[:, c[2]] else: v = vKeyInVariableSetValues[:, 1] - + # check if v is None or not a tensor if v is None or not torch.is_tensor(v): continue - + if not(isinstance(v, torch.FloatTensor) or isinstance(v, torch.cuda.FloatTensor)): v = v.float() - + vSoftmaxT = torch.nn.functional.softmax(v, dim=-1) self.attributes["variableSet"][softmaxKeyInVariableSet] = vSoftmaxT - + # argmax argmaxKeyInVariableSet = vKeyInVariableSet + "/argmax" if not self.hasAttribute(argmaxKeyInVariableSet): @@ -1402,39 +1404,39 @@ def infer(self): v = vKeyInVariableSetValues[:, c[2]] else: v = vKeyInVariableSetValues[:, 1] - + vArgmaxTInxexes = torch.argmax(v, dim=-1) vArgmax = torch.zeros_like(v).scatter_(-1, vArgmaxTInxexes.unsqueeze(-1), 1.) 
- + self.attributes["variableSet"][argmaxKeyInVariableSet] = vArgmax - - # This is test + + # This is test if False: - dns = self.findDatanodes(select = cRoot) + dns = self.findDatanodes(select = cRoot) for dn in dns: keyArgmax = "<" + c[0].name + ">/argmax" keySoftMax = "<" + c[0].name + ">/softmax" - + index = c[2] if index is None: index = 1 - + s = dn.getAttribute(keySoftMax)[c[2]] a = dn.getAttribute(keyArgmax)[c[2]] continue - - else: - # ---- loop through dns + + else: + # ---- loop through dns dns = self.findDatanodes(select = cRoot) - + if not dns: continue - + vs = [] - + for dn in dns: v = dn.getAttribute(c[0]) - + if v is None: vs = [] break @@ -1450,263 +1452,263 @@ def infer(self): break else: vs.append(v[1]) - + if not vs: continue - + t = torch.tensor(vs) t[torch.isnan(t)] = 0 # NAN -> 0 - + vM = torch.argmax(t).item() # argmax - + # Elements for softmax tExp = torch.exp(t) tExpSum = torch.sum(tExp).item() - + keyArgmax = "<" + c[0].name + ">/argmax" keySoftMax = "<" + c[0].name + ">/softmax" - + # Add argmax and softmax to DataNodes - for dn in dns: + for dn in dns: if keyArgmax not in dn.attributes: dn.attributes[keyArgmax] = torch.empty(c[3], dtype=torch.float) - + if dn.getInstanceID() == vM: dn.attributes[keyArgmax][c[2]] = 1 else: dn.attributes[keyArgmax][c[2]] = 0 - + if keySoftMax not in dn.attributes: dn.attributes[keySoftMax] = torch.empty(c[3], dtype=torch.float) - + dnSoftmax = tExp[dn.getInstanceID()]/tExpSum dn.attributes[keySoftMax][c[2]] = dnSoftmax.item() def inferLocal(self, keys=("softmax", "argmax"), Acc=None): """ Infer local probabilities and information for given concepts and relations. - + Args: keys (tuple): Tuple containing the types of information to infer ('softmax', 'argmax', etc.). Acc (dict, optional): A dictionary containing some form of accumulated data for normalization. - + Attributes affected: - This function manipulates the 'attributes' dictionary attribute of the class instance. - + Notes: - The method uses PyTorch for tensor operations. - Logging is done to capture the time taken for inferring local probabilities. 
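        Example:
            A minimal usage sketch (illustrative; computes the local softmax and argmax
            attributes for every concept collected from the data graph):

                dn.inferLocal(keys=("softmax", "argmax"))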
""" startInferLocal = perf_counter() - # Go through keys and remove anything from each of them which is before slash, including slash + # Go through keys and remove anything from each of them which is before slash, including slash keys = [key[key.rfind('/')+1:] for key in keys] - - conceptsRelations = self.collectConceptsAndRelations() - + + conceptsRelations = self.collectConceptsAndRelations() + normalized_keys = set([ - "normalizedProb", "meanNormalizedProb", + "normalizedProb", "meanNormalizedProb", "normalizedProbAll", "meanNormalizedProbStd", "normalizedProbAcc", "entropyNormalizedProbAcc", "normalizedJustAcc", ]) - + if "softmax" in keys or normalized_keys.intersection(set(keys)): needSoftmax = True else: needSoftmax = False - + for c in conceptsRelations: cRoot = self.findRootConceptOrRelation(c[0]) inferLocalKeys = list(keys) # used to check if all keys are calculated - + # ----- skeleton - tensor if getDnSkeletonMode() and "variableSet" in self.attributes: - + vKeyInVariableSet = cRoot.name + "/<" + c[0].name +">" - + if needSoftmax: localSoftmaxKeyInVariableSet = vKeyInVariableSet + "/local/softmax" - + if "softmax" in inferLocalKeys: inferLocalKeys.remove("softmax") - + if not self.hasAttribute(localSoftmaxKeyInVariableSet): v = self.attributes["variableSet"][vKeyInVariableSet] - + # check if v is None or not a tensor if v is None or not torch.is_tensor(v): continue - + if not(isinstance(v, torch.FloatTensor) or isinstance(v, torch.cuda.FloatTensor)): v = v.float() - + vSoftmaxT = torch.nn.functional.softmax(v, dim=-1) self.attributes["variableSet"][localSoftmaxKeyInVariableSet] = vSoftmaxT - + if "argmax" in keys: localArgmaxKeyInVariableSet = vKeyInVariableSet + "/local/argmax" inferLocalKeys.remove("argmax") - + if not self.hasAttribute(localArgmaxKeyInVariableSet): v = self.attributes["variableSet"][vKeyInVariableSet] - + vArgmaxTInxexes = torch.argmax(v, dim=1) vArgmax = torch.zeros_like(v).scatter_(1, vArgmaxTInxexes.unsqueeze(1), 1.) - + self.attributes["variableSet"][localArgmaxKeyInVariableSet] = vArgmax - + # check if we already processed all keys using skeleton if not inferLocalKeys: continue - + # ---- loop through dns dns = self.findDatanodes(select = cRoot) if not dns: continue - + vs = [] - + for dn in dns: if needSoftmax: keySoftmax = "<" + c[0].name + ">/local/softmax" - if not dn.hasAttribute(keySoftmax): + if not dn.hasAttribute(keySoftmax): v = dn.getAttribute(c[0]) - + # check if v is None or not a tensor if v is None or not torch.is_tensor(v): continue - + if not(isinstance(v, torch.FloatTensor) or isinstance(v, torch.cuda.FloatTensor)): v = v.float() - + vSoftmaxT = torch.nn.functional.softmax(v, dim=-1) dn.attributes[keySoftmax] = vSoftmaxT.squeeze(0) - + if "normalizedProb" in keys: keyNormalizedProb = "<" + c[0].name + ">/local/normalizedProb" - if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? + if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? vSoftmaxT = dn.getAttribute(keySoftmax) - + # Clamps the softmax probabilities - vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18) - + vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18) + # Calculates their entropy; entropy = torch.distributions.Categorical(torch.log(vector)).entropy() / vector.shape[0] - + # Multiplies the reverse of entropy to the vector divided by its mean value. 
P vNormalizedProbT = (1/entropy.item()) * (vector/torch.mean(vector)) - + dn.attributes[keyNormalizedProb] = vNormalizedProbT if "normalizedProbAcc" in keys: keyNormalizedProb = "<" + c[0].name + ">/local/normalizedProbAcc" - if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? + if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? vSoftmaxT = dn.getAttribute(keySoftmax) # Clamps the softmax probabilities - vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18) - + vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18) + ### Calculate the multiplier factor if Acc and c[0].name in Acc: multiplier = pow(Acc[c[0].name], 4) else: multiplier = 1 - + # Calculates their entropy; entropy = torch.distributions.Categorical(torch.log(vector)).entropy() / vector.shape[0] - + # Multiplies the reverse of entropy to the vector divided by its mean value. P vNormalizedProbT = (1/entropy.item()) * (vector/torch.mean(vector)) if multiplier != 1: vNormalizedProbT = vNormalizedProbT * multiplier - + dn.attributes[keyNormalizedProb] = vNormalizedProbT if "entropyNormalizedProbAcc" in keys: keyNormalizedProb = "<" + c[0].name + ">/local/entropyNormalizedProbAcc" - if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? + if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? vSoftmaxT = dn.getAttribute(keySoftmax) # Clamps the softmax probabilities - vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18) - + vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18) + ### Calculate the multiplier factor if Acc and c[0].name in Acc: multiplier = pow(Acc[c[0].name], 4) else: multiplier = 1 - + # Calculates their entropy; entropy = torch.distributions.Categorical(torch.log(vector)).entropy() / vector.shape[0] - + # Multiplies the reverse of entropy to the vector divided by its mean value. P vNormalizedProbT = (1/entropy.item()) * vector if multiplier != 1: vNormalizedProbT = vNormalizedProbT * multiplier - + dn.attributes[keyNormalizedProb] = vNormalizedProbT if "normalizedJustAcc" in keys: keyNormalizedProb = "<" + c[0].name + ">/local/normalizedJustAcc" - if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? + if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? vSoftmaxT = dn.getAttribute(keySoftmax) - + ### Calculate the multiplier factor if Acc and c[0].name in Acc: multiplier = pow(Acc[c[0].name], 8) else: multiplier = 1 - + # Calculates their entropy; - + # Multiplies the reverse of entropy to the vector divided by its mean value. P vNormalizedProbT = vSoftmaxT if multiplier != 1: vNormalizedProbT = vNormalizedProbT * multiplier - + dn.attributes[keyNormalizedProb] = vNormalizedProbT if "meanNormalizedProb" in keys: keyNormalizedProb = "<" + c[0].name + ">/local/meanNormalizedProb" - if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? + if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? vSoftmaxT = dn.getAttribute(keySoftmax) vector = vSoftmaxT - + # Multiplies the reverse of entropy to the vector divided by its mean value. P vNormalizedProbT = vector/torch.mean(vector) - + dn.attributes[keyNormalizedProb] = vNormalizedProbT if "normalizedProbAll" in keys: keyNormalizedProb = "<" + c[0].name + ">/local/normalizedProbAll" - if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? + if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? 
vSoftmaxT = dn.getAttribute(keySoftmax) # Clamps the softmax probabilities - vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18) - + vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18) + # Calculates their entropy; entropy = torch.distributions.Categorical(torch.log(vector)).entropy() / vector.shape[0] - + signs = vector - torch.mean(vector) signs[signs < 0] = -1 signs[signs >= 0] = +1 adjustment = signs * torch.pow(vector - torch.mean(vector), 4) - + # Multiplies the reverse of entropy to the vector divided by its mean value. P vNormalizedProbT = (1/entropy.item()) * (vector/torch.mean(vector)) + adjustment - + dn.attributes[keyNormalizedProb] = vNormalizedProbT if "meanNormalizedProbStd" in keys: keyNormalizedProb = "<" + c[0].name + ">/local/meanNormalizedProbStd" - if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? + if not dn.hasAttribute(keyNormalizedProb): # Already calculated ? vSoftmaxT = dn.getAttribute(keySoftmax) vector = vSoftmaxT @@ -1715,12 +1717,12 @@ def inferLocal(self, keys=("softmax", "argmax"), Acc=None): signs[signs < 0] = -1 signs[signs >= 0] = +1 adjustment = signs * torch.pow(vector - torch.mean(vector), 2) - + # Multiplies the reverse of entropy to the vector divided by its mean value. P vNormalizedProbT = (adjustment/torch.pow(torch.mean(vector), 2)) - + dn.attributes[keyNormalizedProb] = vNormalizedProbT - + if "argmax" in keys: keyArgmax = "<" + c[0].name + ">/local/argmax" if not dn.hasAttribute(keyArgmax): @@ -1729,18 +1731,18 @@ def inferLocal(self, keys=("softmax", "argmax"), Acc=None): vArgmaxCalculated = torch.argmax(v, keepdim=True) vArgmaxIndex = torch.argmax(v).item() vArgmax[vArgmaxIndex] = 1 - + dn.attributes[keyArgmax] = vArgmax - + endInferLocal = perf_counter() elapsedInferLocalInMs = (endInferLocal - startInferLocal) * 1000 self.myLoggerTime.info('Infer Local Probabilities - keys: %s, time: %dms', keys, elapsedInferLocalInMs) - + def inferILPResults(self, *_conceptsRelations, key=("local", "softmax"), fun=None, epsilon=0.00001, minimizeObjective=False, ignorePinLCs=False, Acc=None): """ Calculate ILP (Integer Linear Programming) prediction for a data graph using this instance as the root. Based on the provided list of concepts and relations, it initiates ILP solving procedures. - + Parameters: - *_conceptsRelations: tuple The concepts and relations used for inference. @@ -1756,10 +1758,10 @@ def inferILPResults(self, *_conceptsRelations, key=("local", "softmax"), fun=Non Whether to ignore pin constraints, default is False. - Acc: object, optional An accumulator for collecting results, default is None. - + Raises: - DataNodeError: When no concepts or relations are found for inference. - + Returns: - None: This function operates in-place and does not return a value. 
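        Example:
            A minimal usage sketch (illustrative; `people` and `organization` are
            hypothetical Concept objects and `dn` is the root DataNode of the data graph):

                dn.inferILPResults(people, organization, key=("local", "softmax"))
                ilp_people = dn.collectInferredResults(people, "ILP")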
""" @@ -1767,38 +1769,38 @@ def inferILPResults(self, *_conceptsRelations, key=("local", "softmax"), fun=Non _DataNode__Logger.info('Called with empty list of concepts and relations for inference.') else: _DataNode__Logger.info(f'Called with the following list of concepts and relations for inference: {[x.name if isinstance(x, Concept) else x for x in _conceptsRelations]}') - + # Check if a full data node is created; if not, create it as it's needed for ILP inference if self.myBuilder: self.myBuilder.createFullDataNode(self) - + # Collect all relevant concepts and relations from the data graph _conceptsRelations = self.collectConceptsAndRelations(_conceptsRelations) - + if not _conceptsRelations: _DataNode__Logger.error(f'No concepts or relations found for inference in the provided DataNode {self}.') raise DataNode.DataNodeError(f'No concepts or relations found for inference in the provided DataNode {self}.') - else: + else: _DataNode__Logger.info(f'Found the following set of concepts and relations for inference: {[x[1] if isinstance(x, tuple) else x for x in _conceptsRelations]}') - + myILPOntSolver, conceptsRelations = self.getILPSolver(_conceptsRelations) - + _DataNode__Logger.info("Initiating ILP solver") - + if "local" in key: keys = (key[1],) self.inferLocal(keys=keys, Acc=Acc) - + startILPInfer = perf_counter() if self.graph.batch and self.ontologyNode == self.graph.batch and 'contains' in self.relationLinks: batchConcept = self.graph.batch self.myLoggerTime.info(f'Batch processing ILP for {batchConcept}') - + for batchIndex, dn in enumerate(self.relationLinks['contains']): startILPBatchStepInfer = perf_counter() myILPOntSolver.calculateILPSelection(dn, *conceptsRelations, key=key, fun=fun, epsilon=epsilon, minimizeObjective=minimizeObjective, ignorePinLCs=ignorePinLCs) endILPBatchStepInfer = perf_counter() - + elapsed = endILPBatchStepInfer - startILPBatchStepInfer if elapsed > 1: self.myLoggerTime.info(f'Finished step {batchIndex} for batch ILP Inference - time: {elapsed:.2f}s') @@ -1806,34 +1808,34 @@ def inferILPResults(self, *_conceptsRelations, key=("local", "softmax"), fun=Non self.myLoggerTime.info(f'Finished step {batchIndex} for batch ILP Inference - time: {elapsed*1000:.2f}ms') else: myILPOntSolver.calculateILPSelection(self, *conceptsRelations, key=key, fun=fun, epsilon=epsilon, minimizeObjective=minimizeObjective, ignorePinLCs=ignorePinLCs) - + endILPInfer = perf_counter() - + elapsed = endILPInfer - startILPInfer if elapsed > 1: self.myLoggerTime.info(f'Completed ILP Inference - total time: {elapsed:.2f}s') else: self.myLoggerTime.info(f'Completed ILP Inference - total time: {elapsed*1000:.2f}ms') - + self.myLoggerTime.info('') - + def inferGBIResults(self, *_conceptsRelations, model, kwargs): """ Infer Grounded Belief Inference (GBI) results based on given concepts and relations. - + Parameters: - _conceptsRelations: tuple or list Concepts and relations for which GBI is to be calculated. If empty, collects all from the graph. - model: object Solver model to be used in the GBI calculation. - + Returns: None. The function modifies the state of the `self.graph` object to store GBI results. - + Logging: - Logs whether the function was called with an empty or non-empty list of concepts and relations. - Logs other debug and informational messages. - + Side Effects: - Modifies the state of the `self.graph` object to store GBI results. 
""" @@ -1841,36 +1843,36 @@ def inferGBIResults(self, *_conceptsRelations, model, kwargs): _DataNode__Logger.info('Called with empty list of concepts and relations for inference') else: _DataNode__Logger.info('Called with - %s - list of concepts and relations for inference'%([x.name if isinstance(x, Concept) else x for x in _conceptsRelations])) - + # Check if concepts and/or relations have been provided for inference, if provide translate then to tuple concept info form _conceptsRelations = self.collectConceptsAndRelations(_conceptsRelations) # Collect all concepts and relations from graph as default set from domiknows.program.model.gbi import GBIModel from inspect import signature cmodelSignature = signature(GBIModel.__init__) - + cmodelKwargs = {} for param in cmodelSignature.parameters.values(): paramName = param.name if paramName in kwargs: cmodelKwargs[paramName] = kwargs[paramName] - + myGBIModel = GBIModel(self.graph, solver_model=model, **cmodelKwargs) myGBIModel.calculateGBISelection(self, _conceptsRelations) - + def verifyResultsLC(self, key = "/local/argmax"): """ - Verify the results of ILP (Integer Linear Programming) by checking the percentage of + Verify the results of ILP (Integer Linear Programming) by checking the percentage of results satisfying each logical constraint (LC). - + Parameters: - key: str, optional Specifies the method used for verification. Supported keys are those containing "local" or "ILP". Default is "/local/argmax". - + Raises: - DataNodeError: When an unsupported key is provided. - + Returns: - verifyResult: object The result of the verification, typically a data structure containing percentages of @@ -1881,25 +1883,25 @@ def verifyResultsLC(self, key = "/local/argmax"): # Check if full data node is created and create it if not if self.myBuilder != None: self.myBuilder.createFullDataNode(self) - + if "local" in key: - self.inferLocal(keys=[key]) + self.inferLocal(keys=[key]) elif "ILP" in key: self.infer() else: _DataNode__Logger.error("Not supported key %s for verifyResultsLC"%(key)) - + verifyResult = myilpOntSolver.verifyResultsLC(self, key = key) - + return verifyResult - - def calculateLcLoss(self, tnorm='P', sample=False, sampleSize=0, sampleGlobalLoss=False): + + def calculateLcLoss(self, tnorm='P',counting_tnorm=None, sample=False, sampleSize=0, sampleGlobalLoss=False): """ Calculate the loss for logical constraints (LC) based on various t-norms. - + Parameters: - tnorm: str, optional - Specifies the t-norm used for calculations. Supported t-norms are 'L' (Lukasiewicz), + Specifies the t-norm used for calculations. Supported t-norms are 'L' (Lukasiewicz), 'G' (Godel), and 'P' (Product). Default is 'P'. - sample: bool, optional Specifies whether sampling is to be used. Default is False. @@ -1908,30 +1910,30 @@ def calculateLcLoss(self, tnorm='P', sample=False, sampleSize=0, sampleGlobalLos Default is 0. - sampleGlobalLoss: bool, optional Specifies whether to calculate the global loss in case of sampling. Default is False. - + Returns: - lcResult: object The calculated loss for logical constraints, typically a numeric value or data structure. - + Raises: - DataNodeError: When an unsupported tnorm is provided or other internal errors occur. 
""" self.myBuilder.createFullDataNode(self) - + myilpOntSolver, conceptsRelations = self.getILPSolver(conceptsRelations=self.collectConceptsAndRelations()) - + self.inferLocal() - lcResult = myilpOntSolver.calculateLcLoss(self, tnorm=tnorm, sample=sample, - sampleSize=sampleSize, sampleGlobalLoss=sampleGlobalLoss, + lcResult = myilpOntSolver.calculateLcLoss(self, tnorm=tnorm,counting_tnorm=counting_tnorm, sample=sample, + sampleSize=sampleSize, sampleGlobalLoss=sampleGlobalLoss, conceptsRelations=conceptsRelations) - + return lcResult def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, average='binary'): """ Calculate inference metrics for given concepts and relations. - + Parameters: - conceptsRelations: tuple or list Concepts and relations for which metrics are to be calculated. If empty, it collects all. @@ -1941,11 +1943,11 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av Weight tensor to be used in the calculation. - average: str, optional (default is 'binary') Type of average to be used in metrics calculation. Can be 'binary', 'micro', etc. - + Returns: - result: dict Dictionary containing calculated metrics (TP, FP, TN, FN, P, R, F1) for each concept. - + Logging: - Various logs are printed for debugging and information. """ @@ -1955,27 +1957,27 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av _DataNode__Logger.info("Found conceptsRelations in DataNode- %s"%(conceptsRelations)) else: _DataNode__Logger.info("Calling %s metrics with conceptsRelations - %s"%(inferType, conceptsRelations)) - + weightOriginal = weight if weight is None: weight = torch.tensor(1) else: _DataNode__Logger.info("Using weight %s"%(weight)) - - # Will store calculated metrics an related data - result = {} - tp, fp, tn, fn = [], [], [], [] - isBinary = False + + # Will store calculated metrics an related data + result = {} + tp, fp, tn, fn = [], [], [], [] + isBinary = False isMulticlass = False isMulticlassLabel = False - + # Calculate metrics for each provided concept for cr in conceptsRelations: # Check format of concepts and translate them to tuple in order to accommodate multiclass concepts if not isinstance(cr, tuple): # Not tuple concept form yet if not isinstance(cr, Concept): # If string find the corresponding concept cr = self.findConcept(cr) - + if cr is None: # Sting mapping to concept is not found _DataNode__Logger.error("% string is not a concept - not able to calculate metrics"%(cr)) continue @@ -1986,7 +1988,7 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av else: _DataNode__Logger.error("% string is not a concept - not able to calculate metrics"%(cr)) continue - + _DataNode__Logger.info("Calculating metrics for concept %s"%(cr[0])) # Collect date for metrics from DataNode @@ -2005,28 +2007,28 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av continue else: _DataNode__Logger.info("Concept %s labels from DataNode %s"%(cr[1], labelsR)) - + if not torch.is_tensor(preds): _DataNode__Logger.error("Concept %s labels is not a Tensor - not able to calculate metrics"%(cr[1])) continue - + if not torch.is_tensor(labelsR): _DataNode__Logger.error("Concept %s predictions is not a Tensor - not able to calculate metrics"%(cr[1])) continue - + # Move to CPU if preds.is_cuda: preds = preds.cpu() if labelsR.is_cuda: labelsR = labelsR.cpu() - + # Translate labels - if provided as True/False to long labels = torch.clone(labelsR) labels = labels.long() 
preds = preds.long() - + # -- Multiclass processing - + # Check if concept is a label from Multiclass - if cr[2] is not None: # Multiclass label given multiclass index (cr[2]) + if cr[2] is not None: # Multiclass label given multiclass index (cr[2]) isMulticlassLabel = True average = None labelsList = [i for i in range(cr[3])] @@ -2039,17 +2041,17 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av if preds.shape[0] == len(labelsR): predsOriginal = preds preds = torch.nonzero(preds, as_tuple=True)[1] - + if preds.shape[0] != len(labelsR): _DataNode__Logger.warning("Concept %s predictions tensor has some predictions not calculated - %s"%(cr[1], predsOriginal)) - + _DataNode__Logger.info("Concept %s is Multiclass "%(cr[1])) _DataNode__Logger.info("Using average %s for Multiclass metrics calculation"%(average)) else: _DataNode__Logger.error("Incompatible lengths for %s between inferred results %s and labels %s"%(cr[1], len(preds), len(labelsR))) continue - + _DataNode__Logger.info("Calculating metrics for all class Labels of %s "%(cr[1])) multiclassLabels = cr[0].enum result = self.getInferMetrics(*multiclassLabels, inferType=inferType, weight = weightOriginal, average=average) @@ -2058,34 +2060,34 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av labelsList = None # --- - + # Check if date prepared correctly if preds.dim() != 1: _DataNode__Logger.error("Concept %s predictions is Tensor with dimension %s > 1- not able to calculate metrics"%(cr[1], preds.dim())) continue - + if labels.dim() != 1: _DataNode__Logger.error("Concept %s labels is Tensor with dimension %s > 1- not able to calculate metrics"%(cr[1], labels.dim())) continue - + if preds.size()[0] != labels.size()[0]: _DataNode__Logger.error("Concept %s labels size %s is not equal to prediction size %s - not able to calculate metrics"%(cr[1], labels.size()[0], preds.size()[0])) continue - + # Prepare the metrics result storage result[cr[1]] = {'cr': cr, 'inferType' : inferType, 'TP': torch.tensor(0.), 'FP': torch.tensor(0.), 'TN': torch.tensor(0.), 'FN': torch.tensor(0.)} - + # To numpy for sklearn - labels = labels.numpy() + labels = labels.numpy() preds = preds.numpy() - + import numpy as np if np.sum(labels) == 0: _DataNode__Logger.warning("Concept %s - found all zero labels %s"%(cr[1], labels)) else: _DataNode__Logger.info("Concept %s - labels used for metrics calculation %s"%(cr[1], labels)) result[cr[1]]['labels'] = labels - + if np.sum(preds) == 0: _DataNode__Logger.warning("Concept %s - found all zero predictions %s"%(cr[1], preds)) else: @@ -2102,26 +2104,26 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av elif isBinary: cm = metrics.confusion_matrix(labels, preds) _tn, _fp, _fn, _tp = cm.ravel() - - tp.append(_tp) - result[cr[1]]['TP'] = _tp # true positive - + + tp.append(_tp) + result[cr[1]]['TP'] = _tp # true positive + fp.append(_fp) result[cr[1]]['FP'] = _fp # false positive - + tn.append(_tn) result[cr[1]]['TN'] = _tn # true negative - + fn.append(_fn) result[cr[1]]['FN'] = _fn # false positive else: pass - + result[cr[1]]['confusion_matrix'] = cm _DataNode__Logger.info("Concept %s confusion matrix %s"%(cr[1], result[cr[1]]['confusion_matrix'])) except ValueError as ve: # Error when both labels and preds as zeros _DataNode__Logger.warning("Concept %s - both labels and predictions are all zeros - not able to calculate confusion metrics"%(cr[1])) - + # Calculate precision P - tp/(tp + fp) _p = 
metrics.precision_score(labels, preds, average=average, labels=labelsList, zero_division=0) # precision or positive predictive value (PPV) if isMulticlassLabel: @@ -2141,7 +2143,7 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av _DataNode__Logger.warning("Concept %s recall %s"%(cr[1], _r)) else: _DataNode__Logger.info("Concept %s recall %s"%(cr[1], _r)) - + # Calculate F1 score - (P X R)/(P + R) _f1 = metrics.f1_score(labels, preds, average=average, labels=labelsList, zero_division=0) # f1 if isMulticlassLabel: @@ -2154,31 +2156,31 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av # --- Calculate Total metrics for binary concept if isBinary: - result['Total'] = {} + result['Total'] = {} tpT = (torch.tensor(tp)).sum() - result['Total']['TP'] = tpT - fpT = (torch.tensor(fp)).sum() + result['Total']['TP'] = tpT + fpT = (torch.tensor(fp)).sum() result['Total']['FP'] = fpT - tnT = (torch.tensor(tn)).sum() + tnT = (torch.tensor(tn)).sum() result['Total']['TN'] = tnT - fnT = (torch.tensor(fn)).sum() + fnT = (torch.tensor(fn)).sum() result['Total']['FN'] = fnT - + if tpT + fpT: - pT = tpT / (tpT + fpT) + pT = tpT / (tpT + fpT) result['Total']['P'] = pT if pT == 0: _DataNode__Logger.warning("Total precision is %s"%(pT)) else: _DataNode__Logger.info("Total precision is %s"%(pT)) - + rT = tpT / (tpT + fnT) result['Total']['R'] = rT if rT == 0: _DataNode__Logger.warning("Total recall is %s"%(rT)) else: _DataNode__Logger.info("Total recall is %s"%(rT)) - + if pT + rT: f1T = 2 * pT * rT / (pT + rT) result['Total']['F1'] = f1T @@ -2186,7 +2188,7 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av _DataNode__Logger.warning("Total F1 is %s"%(f1T)) else: _DataNode__Logger.info("Total F1 is %s"%(f1T)) - + elif tpT + (fpT + fnT)/2: f1T = tpT/(tpT + (fpT + fnT)/2) result['Total']['F1'] = f1T @@ -2195,23 +2197,23 @@ def getInferMetrics(self, *conceptsRelations, inferType='ILP', weight = None, av else: _DataNode__Logger.info("Total F1 is %s"%(f1T)) else: - _DataNode__Logger.warning("No able to calculate F1 for Total") + _DataNode__Logger.warning("No able to calculate F1 for Total") else: result['Total'] = {"No Total metrics for multiclass concept"} return result - + # Class constructing the data graph based on the sensors data during the model execution class DataNodeBuilder(dict): """ DataNodeBuilder class that extends Python's built-in dictionary. - + Attributes: - context (str): The context in which the DataNodeBuilder is being used, defaults to "build". - myLoggerTime: Logger time instance for logging purposes. - skeletonDataNode: Data structure for the basic DataNode skeleton. - skeletonDataNodeFull: Data structure for the full DataNode skeleton. - + Methods: - __init__: Initializes the DataNodeBuilder instance. - __getitem__: Overrides dict's __getitem__ to fetch item for a given key. @@ -2219,15 +2221,15 @@ class DataNodeBuilder(dict): """ context = "build" - + def __init__(self, *args, **kwargs): """ Initialize the DataNodeBuilder instance. - + Parameters: - args: Positional arguments to pass to the dict constructor. - kwargs: Keyword arguments to pass to the dict constructor. - + Side Effects: - Logs the instance creation. - Initializes various instance variables. 
@@ -2243,17 +2245,17 @@ def __init__(self, *args, **kwargs): dict.__setitem__(self, "DataNodesConcepts", {}) dict.__setitem__(self, "KeysInOrder", []) - + if args: dict.__setitem__(self, "data_item", args[0]) def __getitem__(self, key): """ Override dictionary's __getitem__ to fetch item for a given key. - + Parameters: - key: The key to look for in the dictionary. - + Returns: The value associated with the provided key. """ @@ -2262,10 +2264,10 @@ def __getitem__(self, key): def __changToTuple(self, v): """ Change elements of value to tuple if they are list, in order to use the value as dictionary keys. - + Parameters: - v: The value to be converted. - + Returns: The value converted to tuple form if it was a list; otherwise, the original value. """ @@ -2273,49 +2275,49 @@ def __changToTuple(self, v): _v = [] for v1 in v: _v.append(self.__changToTuple(v1)) - + return tuple(_v) else: return v - + def __addVariableNameToSet(self, vName): """ Add a variable name to the internal 'variableSet'. - + This method checks if 'variableSet' exists in the dictionary. If it does not exist, it is created. The provided variable name is then added to this set. - + Args: vName (str): The variable name to add to the set. """ variableSetName = 'variableSet' if not dict.__contains__(self, variableSetName): dict.__setitem__(self, variableSetName, set()) - + variableSet = dict.__getitem__(self, variableSetName) variableSet.add(vName) - + def __addPropertyNameToSet(self, pName): """ Add a property name to the internal 'propertySet'. - + This method checks if 'propertySet' exists in the dictionary. If it does not exist, it is created. The provided property name is then added to this set. - + Args: pName (str): The property name to add to the set. """ propertySetName = 'propertySet' if not dict.__contains__(self, propertySetName): dict.__setitem__(self, propertySetName, set()) - + propertySet = dict.__getitem__(self, propertySetName) propertySet.add(pName) - + def __addSetitemCounter(self): """ Add or increment a global counter for the number of '__setitem__' calls. - + This method checks if a global counter (named 'Counter_setitem') exists in the dictionary. If it does not exist, it is created and initialized with 1. If it exists, it is incremented by 1. """ @@ -2328,43 +2330,43 @@ def __addSetitemCounter(self): def __addSensorCounters(self, skey, value): """ - Add or increment a sensor-specific counter for the number of '__setitem__' calls + Add or increment a sensor-specific counter for the number of '__setitem__' calls with the given sensor key and value. - - This method constructs a unique counter name based on the sensor key ('skey'). - If the counter doesn't exist, it's created and initialized with the given 'value'. + + This method constructs a unique counter name based on the sensor key ('skey'). + If the counter doesn't exist, it's created and initialized with the given 'value'. If the counter already exists, it's incremented. - + A flag named 'recent' is also used to indicate whether the counter was recently incremented. - + Args: skey (list): The list of keys representing the sensor. Used to construct the unique counter name. value (Any): The value to be counted. If it's a list, it will be converted to a tuple. - + Returns: bool: True if the counter for the given '_value' was recently incremented, False otherwise. 
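        Example:
            An illustrative sketch of the intended effect; the key path is hypothetical.

                builder['global/sentence/raw'] = ["John works for IBM"]   # first write - value stored
                builder['global/sentence/raw'] = ["John works for IBM"]   # same value again - the repeated
                # key/value combination is detected here and __setitem__ skips the redundant update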
""" _value = value if isinstance(value, list): _value = self.__changToTuple(_value) - + counterNanme = 'Counter' for s in skey: # skey[2:]: counterNanme = counterNanme + '/' + s - + if not dict.__contains__(self, counterNanme): try: dict.__setitem__(self, counterNanme, {_value : {"counter": 1, "recent" : True}}) except TypeError: return False - + return False else: currentCounter = dict.__getitem__(self, counterNanme) - + if _value in currentCounter: - currentCounter[_value]["counter"] = currentCounter[_value]["counter"] + 1 - + currentCounter[_value]["counter"] = currentCounter[_value]["counter"] + 1 + if currentCounter[_value]["recent"]: return True else: @@ -2372,39 +2374,39 @@ def __addSensorCounters(self, skey, value): return False else: currentCounter[_value] = {"counter": 1, "recent" : True} - + return False - + def __findConcept(self, conceptName, usedGraph): """ Search for a concept in the graph based on its name. - + Args: conceptName (str): The name of the concept to search for. usedGraph (Graph object): The graph object where to look for the concept. - + Returns: Concept object: The concept object if found, otherwise None. """ subGraph_keys = [key for key in usedGraph._objs] for subGraphKey in subGraph_keys: subGraph = usedGraph._objs[subGraphKey] - + for conceptNameItem in subGraph.concepts: if conceptName == conceptNameItem: concept = subGraph.concepts[conceptNameItem] - + return concept - return None - + return None + def __findConceptInfo(self, usedGraph, concept): """ Collects and returns information about a given concept as defined in the graph. - + Args: usedGraph (Graph object): The graph object where to look for the concept information. concept (Concept object): The concept object for which information is to be collected. - + Returns: dict: A dictionary containing various pieces of information about the concept. - 'concept': The concept itself. @@ -2424,16 +2426,16 @@ def __findConceptInfo(self, usedGraph, concept): } return conceptInfo - + def __updateConceptInfo(self, usedGraph, conceptInfo, sensor): """ Updates concept information based on the given sensor. - + Args: usedGraph (Graph object): The graph object where to look for the concept. conceptInfo (dict): The existing dictionary containing information about the concept. sensor (Sensor object): The sensor object that is being processed. - + Attributes Updated in conceptInfo dictionary: - 'relationAttrData': A boolean indicating if the destination attribute is equal to the concept. - 'label': A boolean indicating if the sensor has a label attribute and it is set. @@ -2442,30 +2444,30 @@ def __updateConceptInfo(self, usedGraph, conceptInfo, sensor): - 'relationAttrsGraph': Copy of existing 'relationAttrs' if present. - 'relationAttrs': A dictionary with updated source and destination attributes. - 'relationMode': The mode of the relation. - - Note: + + Note: - This method uses `EdgeSensor` from domiknows.sensor.pytorch.relation_sensors for certain operations. - The method updates the 'conceptInfo' dictionary in-place. 
""" from domiknows.sensor.pytorch.relation_sensors import EdgeSensor conceptInfo["relationAttrData"] = False conceptInfo['label'] = False - if hasattr(sensor, 'label') and sensor.label: + if hasattr(sensor, 'label') and sensor.label: conceptInfo['label'] = True if (isinstance(sensor, EdgeSensor)): - + conceptInfo['relationName'] = sensor.relation.name conceptInfo['relationTypeName'] = str(type(sensor.relation)) - + if 'relationAttrs' in conceptInfo: conceptInfo['relationAttrsGraph'] = conceptInfo['relationAttrs'] - + conceptInfo['relationAttrs'] = {} - + conceptInfo['relationMode'] = sensor.relation.mode - conceptInfo['relationAttrs']["src"] = self.__findConcept(sensor.src.name, usedGraph) - conceptInfo['relationAttrs']["dst"] = self.__findConcept(sensor.dst.name, usedGraph) + conceptInfo['relationAttrs']["src"] = self.__findConcept(sensor.src.name, usedGraph) + conceptInfo['relationAttrs']["dst"] = self.__findConcept(sensor.dst.name, usedGraph) if conceptInfo['relationAttrs']["dst"] == conceptInfo['concept']: conceptInfo['relationAttrData'] = True @@ -2473,53 +2475,53 @@ def __updateConceptInfo(self, usedGraph, conceptInfo, sensor): def __isRootDn(self, testedDn, checkedDns, visitedDns): """ Determine if a given DataNode (testedDn) is a root node in the graph based on its impactLinks. - + Args: testedDn (DataNode): The DataNode object that is being tested for its 'root' status. checkedDns (set): A set of DataNodes that have already been examined or should be considered for this check. visitedDns (set, optional): A set of DataNodes that have already been visited during recursion to avoid infinite loops. - + Returns: bool: Returns True if the testedDn is a root node, False otherwise. - + Note: - The method is recursive and visits each node only once to avoid infinite loops. - 'impactLinks' is an attribute of DataNode that shows which DataNodes impact the current DataNode. """ if visitedDns == None: visitedDns = set() - + visitedDns.add(testedDn) - + if not testedDn.impactLinks and testedDn in checkedDns: return False - - isRoot = True + + isRoot = True for _, iDnList in testedDn.impactLinks.items(): # Check if its impacts are connected to Dn in the new Root list if iDnList: for iDn in iDnList: if iDn in visitedDns: continue - + if self.__isRootDn(iDn, checkedDns, visitedDns): isRoot = False break - + if not isRoot: break - + return isRoot - + def __updateRootDataNodeList(self, *dns): """ Update the list of root dataNodes in the dictionary based on newly added dataNodes and existing ones. - + Args: dns (tuple): A tuple containing the dataNodes to be added to the root list. It can contain nested lists. - + Returns: None: The function updates the list of root dataNodes in place and doesn't return any value. - + Notes: - The function first identifies existing root dataNodes and then updates this list based on the new ones. - It uses the `impactLinks` attribute of dataNodes to determine whether a dataNode should be considered a root. 
@@ -2527,14 +2529,14 @@ def __updateRootDataNodeList(self, *dns): """ if not dns: return - + # Get existing roots dataNodes if dict.__contains__(self, 'dataNode'): dnsRoots = dict.__getitem__(self, 'dataNode') _DataNodeBuilder__Logger.debug('Existing elements in the root dataNodes list - %s'%(dnsRoots)) else: dnsRoots = [] - + # First flatten the list of new dataNodes def flatten(dns): for dn in dns: @@ -2545,16 +2547,16 @@ def flatten(dns): # Flatten the list of new dataNodes flattenDns = list(flatten(dns)) - + # Create a set of all unique dataNodes in dnsRoots and flattenDns allDns = set(dnsRoots) allDns.update(flattenDns) - - # -- Update list of existing root dataNotes - + + # -- Update list of existing root dataNotes + # Will be used to store new root dataNodes newDnsRoots = [] - + # Loop over all known unique dataNodes #for dnE in allDns: # Check if the dataNode is a root dataNode because it has no impact link @@ -2566,7 +2568,7 @@ def flatten(dns): # # Check if the current dataNode is still a root dataNode # if self.__isRootDn(dnE, dnsRoots, visitedDns = None): # newDnsRoots.append(dnE) - + # Count the number of incoming links for each dataNode incomingLinks = {dn: 0 for dn in allDns} dnTypes = {} @@ -2575,13 +2577,13 @@ def flatten(dns): dnTypes[dn.ontologyNode].append(dn) else: dnTypes[dn.ontologyNode] = [dn] - + for il in dn.impactLinks: if il in incomingLinks: incomingLinks[dn] += 1 else: incomingLinks[dn] = 1 - + # Find the root dataNodes which have no incoming links newDnsRoots = [dn for dn in allDns if incomingLinks[dn] == 0 or not dn.impactLinks] newDnsRoots = sorted(newDnsRoots, key=lambda dn: len(dnTypes[dn.ontologyNode]), reverse=False) @@ -2591,74 +2593,74 @@ def flatten(dns): newDnsRoots = allDns #newDnsRoots = sorted(newDnsRoots, key=lambda dn: incomingLinks[dn], reverse=True) newDnsRoots = sorted(newDnsRoots, key=lambda dn: len(dnTypes[dn.ontologyNode]), reverse=False) - - # Set the updated root list + + # Set the updated root list if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Updated elements in the root dataNodes list - %s'%(newDnsRoots)) - dict.__setitem__(self, 'dataNode', newDnsRoots) # Updated the dict - + dict.__setitem__(self, 'dataNode', newDnsRoots) # Updated the dict + return - + def __buildRelationLink(self, vInfo, conceptInfo, keyDataName): """ Build or update relation dataNode in the data graph for a given key. - + Args: vInfo (object): Holds information about the value (e.g., tensor details). conceptInfo (dict): Information about the concept the dataNode represents. keyDataName (str): The key name for the attribute in question. - + Returns: None: The method updates the data graph in-place. 
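        Example:
            A minimal sketch of how the cached candidate tensors are read in the body below
            (variable names are illustrative):

                for i, relation_dn in existing_relation_dns.items():
                    for attr in attribute_names:
                        for j, flag in enumerate(relation_attrs_cache[attr][i]):
                            if flag.item() != 0:                      # non-zero marks a candidate
                                relation_dn.addRelationLink(attr, attr_dns[attr][j])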
""" relationName = conceptInfo['concept'].name - + # Check if data graph started existingRootDns = dict.__getitem__(self, 'dataNode') # DataNodes roots - + if not existingRootDns: _DataNodeBuilder__Logger.error('No dataNode created yet - abandon processing relation link dataNode value for %s and attribute %s'%(relationName,keyDataName)) return # No graph yet - information about relation should not be provided yet - + # Find if DataNodes for this relation have been created existingDnsForRelation = self.findDataNodesInBuilder(select = relationName) - + existingDnsForRelationNotSorted = OrderedDict() for dn in existingDnsForRelation: existingDnsForRelationNotSorted[dn.getInstanceID()] = dn - + existingDnsForRelationSorted = OrderedDict(sorted(existingDnsForRelationNotSorted.items())) - + # This is an information about relation attributes if conceptInfo['relationAttrData']: index = keyDataName.index('.') attrName = keyDataName[0:index] - + relationAttrsCacheName = conceptInfo['concept'].name + "RelationAttrsCache" - + if not dict.__contains__(self, relationAttrsCacheName): dict.__setitem__(self, relationAttrsCacheName, {}) - + relationAttrsCache = dict.__getitem__(self, relationAttrsCacheName) relationAttrsCache[attrName] = vInfo.value - + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Caching received data for %s related to relation %s dataNode, found %i existing dataNode of this type - provided value has length %i' %(keyDataName,relationName,len(existingDnsForRelation),vInfo.len)) - + # Find if all the needed attribute were initialized allAttrInit = True for relationAttributeName, _ in conceptInfo['relationAttrsGraph'].items(): if relationAttributeName not in relationAttrsCache: allAttrInit = False break - + if allAttrInit: #Create links for the relation DataNode # Find DataNodes connected by this relation based on graph definition existingDnsForAttr = OrderedDict() # DataNodes for Attributes of the relation for relationAttributeName, relationAttributeConcept in conceptInfo['relationAttrsGraph'].items(): _existingDnsForAttr = self.findDataNodesInBuilder(select = relationAttributeConcept.name) - + if _existingDnsForAttr: existingDnsForAttr[relationAttributeName] = _existingDnsForAttr if not getProductionModeStatus(): @@ -2666,28 +2668,28 @@ def __buildRelationLink(self, vInfo, conceptInfo, keyDataName): else: existingDnsForAttr[relationAttributeName] = [] _DataNodeBuilder__Logger.warning('Not found dataNodes of the attribute %s for concept %s'%(relationAttributeName,relationAttributeConcept.name)) - + attributeNames = [*existingDnsForAttr] - + # Create links between this relation and instance dataNode based on the candidate information provided by sensor for each relation attribute for relationDnIndex, relationDn in existingDnsForRelationSorted.items(): for attributeIndex, attribute in enumerate(attributeNames): candidatesForRelation = relationAttrsCache[attribute][relationDnIndex] - + for candidateIndex, candidate in enumerate(candidatesForRelation): isInRelation = candidate.item() if isInRelation == 0: continue - + candidateDn = existingDnsForAttr[attribute][candidateIndex] - + #if attributeIndex == 0: # candidateDn.addRelationLink(attribute, relationDn) - - relationDn.addRelationLink(attribute, candidateDn) + + relationDn.addRelationLink(attribute, candidateDn) if (not self.skeletonDataNode): relationDn.attributes[keyDataName] = vInfo.value[relationDnIndex] # Add / /Update value of the attribute - + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Create 
links between the relation %s and instance dataNode of types'%(conceptInfo['concept'].name)) else: @@ -2697,24 +2699,24 @@ def __buildRelationLink(self, vInfo, conceptInfo, keyDataName): _DataNodeBuilder__Logger.info('Updating attribute %s in relation link dataNodes %s'%(keyDataName,conceptInfo['concept'].name)) else: _DataNodeBuilder__Logger.info('Adding attribute %s to relation link dataNodes %s'%(keyDataName,conceptInfo['concept'].name)) - + if (not self.skeletonDataNode): for i, rDn in existingDnsForRelationSorted.items(): # Loop through all relation links dataNodes rDn.attributes[keyDataName] = vInfo.value[i] # Add / /Update value of the attribute self.__updateRootDataNodeList(list(existingDnsForRelationSorted.values())) - else: + else: # -- DataNode with this relation already created - update it with new attribute value if not getProductionModeStatus(): if keyDataName in self: _DataNodeBuilder__Logger.info('Updating attribute %s in relation link dataNodes %s'%(keyDataName,conceptInfo['concept'].name)) else: _DataNodeBuilder__Logger.info('Adding attribute %s to relation link dataNodes %s'%(keyDataName,conceptInfo['concept'].name)) - + if len(existingDnsForRelation) != vInfo.len: _DataNodeBuilder__Logger.error('Number of relations is %i and is different then the length of the provided tensor %i'%(len(existingDnsForRelation),vInfo.len)) raise ValueError('Number of relations is %i and is different then the length of the provided tensor %i'%(len(existingDnsForRelation),vInfo.len)) - + if (not self.skeletonDataNode): if len(existingDnsForRelationSorted) == 1: if vInfo.dim == 0: @@ -2728,35 +2730,35 @@ def __buildRelationLink(self, vInfo, conceptInfo, keyDataName): def __createInitialDataNode(self, vInfo, conceptInfo, keyDataName): """ Create initial data nodes for the data graph. - + Args: vInfo (object): Contains information about the value, like its length. conceptInfo (dict): Information about the concept associated with the data node. keyDataName (str): The name of the key for which the data node is being created. - + Returns: list: A list of created DataNode objects. 
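        Example:
            An illustrative sketch of the rule implemented below:

                value of length 1  -> one root dataNode (its id is taken from the "READER" entry
                                      of the builder when present, otherwise 0)
                value of length N  -> N dataNodes, one per element, with instanceID 0 .. N-1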
""" conceptName = conceptInfo['concept'].name dns = [] - - if not getProductionModeStatus(): + + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Creating initial dataNode - provided value has length %i'%(vInfo.len)) if vInfo.len == 1: # Will use "READER" key as an id of the root dataNode instanceValue = "" - + if "READER" in self: instanceID = dict.__getitem__(self, "READER") else: instanceID = 0 - + initialDn = DataNode(myBuilder = self, instanceID = instanceID, instanceValue = instanceValue, ontologyNode = conceptInfo['concept']) - + if (not self.skeletonDataNode): initialDn.attributes[keyDataName] = vInfo.value - + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Created single dataNode with id %s of type %s'%(instanceID,conceptName)) dns.append(initialDn) @@ -2765,28 +2767,28 @@ def __createInitialDataNode(self, vInfo, conceptInfo, keyDataName): instanceValue = "" instanceID = vIndex newInitialDn = DataNode(myBuilder = self, instanceID = instanceID, instanceValue = instanceValue, ontologyNode = conceptInfo['concept']) - + if (not self.skeletonDataNode): newInitialDn.attributes[keyDataName] = v - + dns.append(newInitialDn) - + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Created %i dataNodes of type %s'%(len(dns),conceptName)) - + self.__updateRootDataNodeList(dns) - + return dns - + def __createSingleDataNode(self, vInfo, conceptInfo, keyDataName): """ Create initial data nodes for the data graph. - + Args: vInfo (object): Contains information about the value, like its length. conceptInfo (dict): Information about the concept associated with the data node. keyDataName (str): The name of the key for which the data node is being created. - + Returns: list: A list of created DataNode objects. """ @@ -2794,45 +2796,45 @@ def __createSingleDataNode(self, vInfo, conceptInfo, keyDataName): if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Received information about dataNodes of type %s - value dim is %i and length is %i'%(conceptName,vInfo.dim,vInfo.len)) - # -- Create a single the new dataNode + # -- Create a single the new dataNode instanceValue = "" instanceID = 0 newSingleDn = DataNode(myBuilder = self, instanceID = instanceID, instanceValue = instanceValue, ontologyNode = conceptInfo['concept']) if (not self.skeletonDataNode): newSingleDn.attributes[keyDataName] = vInfo.value - - if not getProductionModeStatus(): + + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Single new dataNode %s created'%(newSingleDn)) self.__updateRootDataNodeList(newSingleDn) - + return [newSingleDn] - + def __createMultiplyDataNode(self, vInfo, conceptInfo, keyDataName): """ Create multiple data nodes based on various conditions. - + Args: vInfo (object): Information about the value, like its dimension and length. conceptInfo (dict): Information about the concept associated with the data nodes. keyDataName (str): The name of the key for which the data node is being created. - + Returns: list: A list of the created DataNode objects. 
""" conceptName = conceptInfo['concept'].name - - # Master List of lists of created dataNodes - each list in the master list represent set of new dataNodes connected to the same parent dataNode + + # Master List of lists of created dataNodes - each list in the master list represent set of new dataNodes connected to the same parent dataNode # (identified by the index in the master list) - dns = [] - + dns = [] + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Received information about dataNodes of type %s - value dim is %i and length is %i'%(conceptName,vInfo.dim,vInfo.len)) # --- Create dataNodes - + # Check the type of sensor data - if vInfo.dim == 0: + if vInfo.dim == 0: _DataNodeBuilder__Logger.warning('Provided value is empty %s - abandon the update'%(vInfo.value)) return elif vInfo.dim == 1: # List with indexes for new DataNodes and data for attribute @@ -2842,20 +2844,20 @@ def __createMultiplyDataNode(self, vInfo, conceptInfo, keyDataName): for vIndex, v in enumerate(vInfo.value): instanceValue = "" instanceID = vIndex - + # Create new DataNode newDn = DataNode(myBuilder = self, instanceID = instanceID, instanceValue = instanceValue, ontologyNode = conceptInfo['concept']) - + # add attribute if (not self.skeletonDataNode): newDn.attributes[keyDataName] = v - - dns.append(newDn) + + dns.append(newDn) elif vInfo.dim == 2: # Two dimensional relation information if "relationMode" in conceptInfo: relatedDnsType = conceptInfo["relationAttrs"]['src'] relatedDns = self.findDataNodesInBuilder(select = relatedDnsType) - + if len(vInfo.value) > 0: try: requiredLenOFRelatedDns = len(vInfo.value[0]) @@ -2863,15 +2865,15 @@ def __createMultiplyDataNode(self, vInfo, conceptInfo, keyDataName): requiredLenOFRelatedDns = 0 else: requiredLenOFRelatedDns = 0 - + if requiredLenOFRelatedDns != len(relatedDns): _DataNodeBuilder__Logger.warning('Value of %s expects %i related dataNode of type %s but the number of existing dataNodes is %i - abandon the update' %(conceptInfo['relationName'],requiredLenOFRelatedDns,relatedDnsType,len(relatedDns))) return - + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Create %i new dataNodes of type %s'%(vInfo.len,conceptName)) - + if not conceptInfo['relation']: _DataNodeBuilder__Logger.info('It is a contain update of type - %s'%(conceptInfo["relationMode"])) if conceptInfo["relationMode"] == "forward": @@ -2883,21 +2885,21 @@ def __createMultiplyDataNode(self, vInfo, conceptInfo, keyDataName): instanceValue = "" instanceID = i newDn = DataNode(myBuilder = self, instanceID = instanceID, instanceValue = instanceValue, ontologyNode = conceptInfo['concept']) - + if (not self.skeletonDataNode): newDn.attributes[keyDataName] = vInfo.value[i] dns.append(newDn) - + # If it is not a regular relation but (Create contain relation between the new DataNode and existing DataNodes if not conceptInfo['relation']: if conceptInfo["relationMode"] == "forward": for index, isRelated in enumerate(vInfo.value[i]): if isRelated == 1: - relatedDns[index].addChildDataNode(newDn) + relatedDns[index].addChildDataNode(newDn) elif conceptInfo["relationMode"] == "backward": for index, isRelated in enumerate(vInfo.value[i]): if isRelated == 1: - newDn.addChildDataNode(relatedDns[index]) + newDn.addChildDataNode(relatedDns[index]) else: if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Create %i new dataNodes of type %s'%(vInfo.len,conceptName)) @@ -2905,43 +2907,43 @@ def __createMultiplyDataNode(self, vInfo, conceptInfo, keyDataName): instanceValue = 
"" instanceID = i newDn = DataNode(myBuilder = self, instanceID = instanceID, instanceValue = instanceValue, ontologyNode = conceptInfo['concept']) - + dns.append(newDn) else: _DataNodeBuilder__Logger.warning('It is an unsupported sensor input - %s'%(vInfo)) - - self.__updateRootDataNodeList(dns) + + self.__updateRootDataNodeList(dns) return dns - + def __updateDataNodes(self, vInfo, conceptInfo, keyDataName): """ Update existing data nodes based on various conditions. - + Notes: - This function is not called when `skeletonDataNode` is on. - + Args: vInfo (object): Information about the value, like its dimension and length. conceptInfo (dict): Information about the concept associated with the data nodes. keyDataName (str): The name of the key for which the data node is being updated. - + """ conceptName = conceptInfo['concept'].name existingDnsForConcept = self.findDataNodesInBuilder(select = conceptName) # Try to get DataNodes of the current concept if not existingDnsForConcept: existingDnsForConcept = self.findDataNodesInBuilder(select = conceptName) - + if not existingDnsForConcept: return - - if not getProductionModeStatus(): + + if not getProductionModeStatus(): if keyDataName in existingDnsForConcept[0].attributes: _DataNodeBuilder__Logger.info('Updating attribute %s in existing dataNodes - found %i dataNodes of type %s'%(keyDataName, len(existingDnsForConcept),conceptName)) else: _DataNodeBuilder__Logger.info('Adding attribute %s in existing dataNodes - found %i dataNodes of type %s'%(keyDataName, len(existingDnsForConcept),conceptName)) - - if len(existingDnsForConcept) > vInfo.len: # Not enough elements in the value + + if len(existingDnsForConcept) > vInfo.len: # Not enough elements in the value _DataNodeBuilder__Logger.warning('Provided value has length %i but found %i existing dataNode - abandon the update'%(vInfo.len,len(existingDnsForConcept))) elif len(existingDnsForConcept) == vInfo.len: # Number of value elements matches the number of found dataNodes if len(existingDnsForConcept) == 0: @@ -2961,14 +2963,14 @@ def __updateDataNodes(self, vInfo, conceptInfo, keyDataName): else: _DataNodeBuilder__Logger.error('Element %i in the list is not a dataNode - skipping it'%(vIndex)) raise ValueError('Element %i in the list is not a dataNode - skipping it'%(vIndex)) - + if keyDataName[0] == '<' and keyDataName[-1] == '>': if "contains" in existingDnsForConcept[0].impactLinks: dnParent = existingDnsForConcept[0].impactLinks["contains"][0] dnParent.attributes[keyDataName] = vInfo.value elif len(existingDnsForConcept) < vInfo.len: # Too many elements in the value _DataNodeBuilder__Logger.warning('Provided value has length %i but found %i existing dataNode - abandon the update'%(vInfo.len,len(existingDnsForConcept))) - + # Check if this is the contain relation update or attribute update if "relationMode" in conceptInfo and not conceptInfo["relation"]: relatedDnsType = conceptInfo["relationAttrs"]['src'] @@ -2979,81 +2981,81 @@ def __updateDataNodes(self, vInfo, conceptInfo, keyDataName): requiredLenOFRelatedDns = len(vInfo.value[0]) else: requiredLenOFRelatedDns = len(vInfo.item()) - + if requiredLenOFRelatedDns != len(relatedDns): _DataNodeBuilder__Logger.error('Provided value expected %i related dataNode of type %s but the number of existing dataNodes is %i - abandon the update' %(requiredLenOFRelatedDns,relatedDnsType,len(relatedDns))) raise ValueError('Provided value expected %i related dataNode of type %s but the number of existing dataNodes is %i - abandon the update' 
%(requiredLenOFRelatedDns,relatedDnsType,len(relatedDns))) - - if not getProductionModeStatus(): + + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('It is a contain update of type - %s'%(conceptInfo["relationMode"])) if conceptInfo["relationMode"] == "forward": _DataNodeBuilder__Logger.info('%s is contain in %s'%(conceptName, relatedDnsType)) else: _DataNodeBuilder__Logger.info('%s is contain in %s'%(relatedDnsType, conceptName)) - + for i in range(0,vInfo.len): exitingDn = existingDnsForConcept[i] - + if conceptInfo["relationMode"] == "forward": for index, isRelated in enumerate(vInfo.value[i]): if isRelated == 1: - relatedDns[index].addChildDataNode(exitingDn) + relatedDns[index].addChildDataNode(exitingDn) elif conceptInfo["relationMode"] == "backward": for index, isRelated in enumerate(vInfo.value[i]): if isRelated == 1: - exitingDn.addChildDataNode(relatedDns[index]) - - self.__updateRootDataNodeList(existingDnsForConcept) - + exitingDn.addChildDataNode(relatedDns[index]) + + self.__updateRootDataNodeList(existingDnsForConcept) + def __buildDataNode(self, vInfo, conceptInfo, keyDataName): """ Build or update a data node in the data graph for a given relationAttributeConcept. - + Notes: - This function will either create initial data nodes, create single data nodes, create multiple data nodes, or update existing data nodes based on various conditions. - + Args: vInfo (object): Information about the value, like its dimension and length. conceptInfo (dict): Information about the concept associated with the data nodes. keyDataName (str): The name of the key for which the data node is being updated or created. - + Returns: object: Newly created or updated data nodes. """ conceptName = conceptInfo['concept'].name - + if not dict.__contains__(self, 'dataNode'): # ------ No DataNode yet return self.__createInitialDataNode(vInfo, conceptInfo, keyDataName) # Done - End the method else: # ---------- DataNodes already created existingDnsForConcept = self.findDataNodesInBuilder(select = conceptName) # Try to get DataNodes of the current concept - - if len(existingDnsForConcept) == 0:# Check if DataNode for this concept already created + + if len(existingDnsForConcept) == 0:# Check if DataNode for this concept already created # No DataNode of this concept created yet - + # If attribute value is a single element - will create a single new DataNode - if vInfo.len == 1 and vInfo.dim < 2: + if vInfo.len == 1 and vInfo.dim < 2: return self.__createSingleDataNode(vInfo, conceptInfo, keyDataName) else: # -- Value is multiple elements return self.__createMultiplyDataNode(vInfo, conceptInfo, keyDataName) else: # DataNode with this concept already created - update it if (not self.skeletonDataNode): self.__updateDataNodes(vInfo, conceptInfo, keyDataName) - + def __addEquality(self, vInfo, conceptInfo, equalityConceptName, keyDataName): """ Add equality relations between existing data nodes of specified concepts based on the provided value information. - + Args: vInfo (object): Information about the value matrix that indicates equality, like its shape. conceptInfo (dict): Information about the concept associated with one set of data nodes. equalityConceptName (str): The name of the second concept associated with another set of data nodes. keyDataName (str): The name of the key for which the data node is being checked for equality. - + Notes: - Logging statements are used to indicate the progress and success of the equality addition. 
- Checks are made to ensure that data nodes exist for both specified concepts before proceeding. @@ -3062,31 +3064,31 @@ def __addEquality(self, vInfo, conceptInfo, equalityConceptName, keyDataName): conceptName = conceptInfo['concept'].name existingDnsForConcept = self.findDataNodesInBuilder(select = conceptName) existingDnsForEqualityConcept = self.findDataNodesInBuilder(select = equalityConceptName) - + if not existingDnsForConcept and not existingDnsForEqualityConcept: _DataNodeBuilder__Logger.warning('No datNodes created for concept %s and equality concept %s'%(conceptName,equalityConceptName)) return - + if not existingDnsForConcept: _DataNodeBuilder__Logger.warning('No datNodes created for concept %s'%(conceptName)) return - + if not existingDnsForEqualityConcept: _DataNodeBuilder__Logger.warning('No datNodes created for equality concept %s'%(equalityConceptName)) return - + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Added equality between dataNodes of types %s and %s'%(conceptName,equalityConceptName)) for conceptDn in existingDnsForConcept: for equalDn in existingDnsForEqualityConcept: - + if conceptDn.getInstanceID() >= vInfo.value.shape[0]: continue - + if equalDn.getInstanceID() >= vInfo.value.shape[1]: continue - + if vInfo.value[conceptDn.getInstanceID(), equalDn.getInstanceID()]: if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('DataNodes of %s is equal to %s'%(conceptDn,equalDn)) @@ -3095,21 +3097,21 @@ def __addEquality(self, vInfo, conceptInfo, equalityConceptName, keyDataName): def __processAttributeValue(self, value, keyDataName): """ Processes the attribute value to determine its structure and nature. - - This method analyzes the attribute value and categorizes it based on its dimensionality, - whether it is a scalar or a list, and its length. It returns a named tuple with this + + This method analyzes the attribute value and categorizes it based on its dimensionality, + whether it is a scalar or a list, and its length. It returns a named tuple with this information, which can be used for further processing or logging. - + Args: value (Union[torch.Tensor, list, scalar]): The value of the attribute to be processed. keyDataName (str): The name of the attribute for which the value is processed. - + Returns: namedtuple: A named tuple 'ValueInfo' with fields 'len', 'value', and 'dim' where - 'len' is the length of the first dimension of the value. - 'value' is the original or processed value. - 'dim' is the number of dimensions of the value. - + Notes: - Tensor or list with a length of 1 is considered as scalar. - It supports Tensor, list, and scalar data types. 
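        Example:
            Illustrative classifications following the rules implemented below:

                5                              -> ValueInfo(len=1, value=5, dim=0)       # plain scalar
                torch.tensor(5)                -> ValueInfo(len=1, value=5, dim=0)       # 0-dim tensor
                ["w1", "w2", "w3"]             -> ValueInfo(len=3, value=[...], dim=1)   # flat list
                [0.1, 0.9] with a '<...>' key  -> treated as a single probability pair (dim=0)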
@@ -3124,13 +3126,13 @@ def __processAttributeValue(self, value, keyDataName): lenV = 1 else: lenV = len(value) - + if not isinstance(value, (torch.Tensor, list)): # It is scalar value - return ValueInfo(len = 1, value = value, dim=0) - + return ValueInfo(len = 1, value = value, dim=0) + if isinstance(value, torch.Tensor) and dimV == 0: # It is a Tensor but also scalar value return ValueInfo(len = 1, value = value.item(), dim=0) - + if (lenV == 1): # It is Tensor or list with length 1 - treat it as scalar if isinstance(value, list) and not isinstance(value[0], (torch.Tensor, list)) : # Unpack the value return ValueInfo(len = 1, value = value[0], dim=0) @@ -3138,12 +3140,12 @@ def __processAttributeValue(self, value, keyDataName): return ValueInfo(len = 1, value = torch.squeeze(value, 0), dim=0) # If it is Tensor or list with length 2 but it is for attribute providing probabilities - assume it is a scalar value - if isinstance(value, list) and lenV == 2 and keyDataName[0] == '<': + if isinstance(value, list) and lenV == 2 and keyDataName[0] == '<': return ValueInfo(lenV = 1, value = value, dim=0) elif isinstance(value, torch.Tensor) and lenV == 2 and dimV == 0 and keyDataName[0] == '<': return ValueInfo(len = 1, value = value, dim=0) - if isinstance(value, list): + if isinstance(value, list): if not isinstance(value[0], (torch.Tensor, list)) or (isinstance(value[0], torch.Tensor) and value[0].dim() == 0): return ValueInfo(len = lenV, value = value, dim=1) elif not isinstance(value[0][0], (torch.Tensor, list)) or (isinstance(value[0][0], torch.Tensor) and value[0][0].dim() == 0): @@ -3156,17 +3158,17 @@ def __processAttributeValue(self, value, keyDataName): elif isinstance(value, torch.Tensor): return ValueInfo(len = lenV, value = value, dim=dimV) - + def collectTime(self, start): """ Collects the time taken for the __setitem__ operation and stores it in internal lists. - + This method calculates the time elapsed for a __setitem__ operation and appends that, along with the start and end timestamps, to respective lists stored in the object. - + Args: start (int): The start time of the __setitem__ operation in nanoseconds. - + Notes: - The time taken for each __setitem__ operation is stored in a list named 'DataNodeTime'. - The start time for each __setitem__ operation is stored in a list named 'DataNodeTime_start'. @@ -3175,37 +3177,37 @@ def collectTime(self, start): # Collect time used for __setitem__ end = perf_counter_ns() currentTime = end - start - + timeList = self.setdefault("DataNodeTime", []) timeList.append(currentTime) startTimeList = self.setdefault("DataNodeTime_start", []) startTimeList.append(start) endTimeList = self.setdefault("DataNodeTime_end", []) endTimeList.append(end) - + def __setitem__(self, _key, value): """ Overloaded __setitem__ method for the DataNodeBuilder class. This method is responsible for adding or updating key-value pairs in the dictionary-like object. - + Parameters: ----------- _key : Sensor, Property, Concept, or str The key to insert into the dictionary. It can be an instance of Sensor, Property, Concept classes, or a string. value : any The value to associate with the key. It can be of any data type. - + Behavior: --------- - If `_key` is a Sensor and its `build` attribute is set to False, the value is directly inserted without further processing. - If `_key` is a Property, the value is directly inserted without further processing. 
- If `_key` is a Concept or a string containing a Concept, additional logic is invoked to update the associated graph and indices. - If the system is not in production mode, additional logging and checks are performed. - + Returns: -------- None - + Side Effects: ------------- - Updates the underlying dictionary. @@ -3217,7 +3219,7 @@ def __setitem__(self, _key, value): start = perf_counter_ns() self.__addSetitemCounter() - + if isinstance(_key, (Sensor, Property, Concept)): key = _key.fullname if isinstance(_key, Sensor) and not _key.build: @@ -3230,7 +3232,7 @@ def __setitem__(self, _key, value): self.collectTime(start) return dict.__setitem__(self, _key, value) - + if isinstance(_key, Property): if isinstance(value, torch.Tensor): _DataNodeBuilder__Logger.debug('No processing Property as key - key - %s, key type - %s, value - %s, shape %s'%(key,type(_key),type(value),value.shape)) @@ -3247,9 +3249,9 @@ def __setitem__(self, _key, value): _DataNodeBuilder__Logger.error('key - %s, type %s is not supported'%(_key,type(_key))) self.collectTime(start) return - + skey = key.split('/') - + # Check if the key with this value has been set recently # If not create a new sensor for it # If yes stop __setitem__ and return - the same value for the key was added last time that key was set @@ -3257,7 +3259,7 @@ def __setitem__(self, _key, value): self.myLoggerTime.info(f"DataNode Builder skipping repeated value for sensor - {skey}") self.collectTime(start) return # Stop __setitem__ for repeated key value combination - + if not getProductionModeStatus(): if isinstance(value, torch.Tensor): _DataNodeBuilder__Logger.info('key - %s, key type - %s, value - %s, shape %s'%(key,type(_key),type(value),value.shape)) @@ -3270,43 +3272,43 @@ def __setitem__(self, _key, value): _DataNodeBuilder__Logger.error('The value for the key %s is None - abandon the update'%(key)) self.collectTime(start) return dict.__setitem__(self, _key, value) - - if len(skey) < 2: + + if len(skey) < 2: _DataNodeBuilder__Logger.error('The key %s has only two elements, needs at least three - abandon the update'%(key)) self.collectTime(start) return dict.__setitem__(self, _key, value) - + usedGraph = dict.__getitem__(self, "graph") # Find if the key include concept from graph - + graphPathIndex = usedGraph.cutGraphName(skey) keyWithoutGraphName = skey[graphPathIndex:] - graphPath = ''.join(map(str, skey[:graphPathIndex])) - + graphPath = ''.join(map(str, skey[:graphPathIndex])) + # Check if found concept in the key if not keyWithoutGraphName: _DataNodeBuilder__Logger.warning('key - %s has not concept part - returning'%(key)) self.collectTime(start) return dict.__setitem__(self, _key, value) - + # Find description of the concept in the graph if isinstance(_key, Sensor): try: - conceptName = _key.concept.name + conceptName = _key.concept.name except TypeError as _: conceptName = keyWithoutGraphName[0] else: conceptName = keyWithoutGraphName[0] concept = self.__findConcept(conceptName, usedGraph) - + if not concept: _DataNodeBuilder__Logger.warning('conceptName - %s has not been found in the used graph %s - returning'%(conceptName,usedGraph.fullname)) self.collectTime(start) return dict.__setitem__(self, _key, value) - + conceptInfo = self.__findConceptInfo(usedGraph, concept) - + if isinstance(_key, Sensor): self.__updateConceptInfo(usedGraph, conceptInfo, _key) @@ -3319,20 +3321,20 @@ def __setitem__(self, _key, value): # Create key for DataNode construction keyDataName = "".join(map(lambda x: '/' + x, keyWithoutGraphName[1:-1])) keyDataName 
= keyDataName[1:] # __cut first '/' from the string - + if conceptInfo['label']: keyDataName += '/label' - + vInfo = self.__processAttributeValue(value, keyDataName) - + # Decide if this is equality between concept data, dataNode creation or update for concept or relation link if keyDataName.find("_Equality_") > 0: equalityConceptName = keyDataName[keyDataName.find("_Equality_") + len("_Equality_"):] self.__addEquality(vInfo, conceptInfo, equalityConceptName, keyDataName) - else: + else: _DataNodeBuilder__Logger.debug('%s found in the graph; it is a concept'%(conceptName)) index = self.__buildDataNode(vInfo, conceptInfo, keyDataName) # Build or update Data node - + if index: indexKey = graphPath + '/' + conceptName + '/index' dict.__setitem__(self, indexKey, index) @@ -3347,14 +3349,14 @@ def __setitem__(self, _key, value): allDns.update(index) except TypeError as ty: pass - + DataNodesConcepts[conceptName] = index #dict.__setitem__(self, "DataNodesConcepts", DataNodesConcepts) - + if conceptInfo['relation']: _DataNodeBuilder__Logger.debug('%s is a relation'%(conceptName)) self.__buildRelationLink(vInfo, conceptInfo, keyDataName) # Build or update relation link - + if self.skeletonDataNode: if conceptName in skey: # Find the index of "conceptName" in skey @@ -3362,7 +3364,7 @@ def __setitem__(self, _key, value): # Join "conceptName" with the next element in skey keyInRootDataNode = "/".join(skey[index:index+2]) - + # Add "/label" to the key if the concept has a label marked if conceptInfo['label']: keyInRootDataNode += "/label" @@ -3375,66 +3377,66 @@ def __setitem__(self, _key, value): else: # throw an exception raise Exception("The key does not contain conceptName") - + # Add key to the list of keys in order if self.skeletonDataNodeFull: KeysInOrder = dict.__getitem__(self, "KeysInOrder") KeysInOrder.append(_key) - + # Add value to the underling dictionary r = dict.__setitem__(self, _key, value) - + if not r: pass # Error when adding entry to dictionary ? - + self.collectTime(start) - return r - + return r + def __delitem__(self, key): """ Overloaded __delitem__ method for the DataNodeBuilder class. This method is responsible for deleting a key-value pair from the dictionary-like object. - + Parameters: ----------- key : any hashable type The key to be deleted from the dictionary. - + Returns: -------- None """ return dict.__delitem__(self, key) - + def __contains__(self, key): """ Overloaded __contains__ method for the DataNodeBuilder class. This method checks if the key is present in the dictionary-like object. - + Parameters: ----------- key : any hashable type The key to be checked for existence in the dictionary. - + Returns: -------- bool True if the key exists, False otherwise. """ return dict.__contains__(self, key) - + def __addGetDataNodeCounter(self): """ Method to increment a counter that keeps track of the number of times the __getitem__ method is called. - + Parameters: ----------- None - + Returns: -------- None - + Side Effects: ------------- - Updates the internal counter for __getitem__ calls. @@ -3445,50 +3447,50 @@ def __addGetDataNodeCounter(self): else: currentCounter = dict.__getitem__(self, counterName) dict.__setitem__(self, counterName, currentCounter + 1) - + def findDataNodesInBuilder(self, select=None, indexes=None): """ Method to find data nodes that meet certain criteria within the DataNodeBuilder object. - + Parameters: ----------- select : function or None, optional A function to apply to each DataNode to determine if it should be selected. 
Defaults to None. indexes : list or None, optional A list of indexes to specifically look for. Defaults to None. - + Returns: -------- list A list of DataNodes that meet the given criteria. """ existingRootDns = dict.__getitem__(self, 'dataNode') # DataNodes roots - + if not existingRootDns: foundDns = [] else: foundDns = existingRootDns[0].findDatanodes(dns=existingRootDns, select=select, indexes=indexes) - + return foundDns def createFullDataNode(self, rootDataNode): """ Method to create a full data node based on the current skeleton of the DataNodeBuilder object. - + Parameters: ----------- rootDataNode : DataNode object The root data node to which attributes will be added. - + Returns: -------- None - + Side Effects: ------------- - Modifies internal state to reflect that a full data node has been created. - Logs time taken to create the full data node. - + Notes: ------ - This method operates under the assumption that the DataNodeBuilder is initially in skeleton mode. @@ -3498,54 +3500,54 @@ def createFullDataNode(self, rootDataNode): startCreateFullDataNode = perf_counter() self.skeletonDataNodeFull = False # Set temporary flag to False to allow creation of full dataNode - + keysInOrder = dict.__getitem__(self, "KeysInOrder") - + for key in keysInOrder: # Run the original values through __setitem__ to build the full dataNode self.__setitem__(key, dict.__getitem__(self, key)) - + if self.skeletonDataNode: # Get the "allDns" set from the data node, or create a new empty set if it doesn't exist allDns = self.get("allDns", set()) - + # Iterate over the data nodes in "allDns" and add the "rootDataNode" attribute to them for dn in allDns: if dn == rootDataNode: continue dn.attributes["rootDataNode"] = rootDataNode - - self.skeletonDataNodeFull = True # Return flag to the original - + + self.skeletonDataNodeFull = True # Return flag to the original + endCreateFullDataNode = perf_counter() elapsedCreateFullDataNode = (endCreateFullDataNode - startCreateFullDataNode) * 1000 - self.myLoggerTime.info(f'Creating Full Datanode: {elapsedCreateFullDataNode}ms') - + self.myLoggerTime.info(f'Creating Full Datanode: {elapsedCreateFullDataNode}ms') + def createBatchRootDN(self): """ Creates a batch root DataNode when certain conditions are met. - + Conditions for creating a new batch root DataNode: - If the DataNodeBuilder object already has a single root DataNode, no new root DataNode will be created. - If the DataNodeBuilder object has DataNodes of different types, a batch root DataNode cannot be created. - + Parameters: ----------- None - + Returns: -------- None - + Side Effects: ------------- - Modifies the 'dataNode' attribute of the DataNodeBuilder object. - Logs messages based on the production mode status and whether a new root DataNode is created or not. - + Raises: ------- - ValueError: When the DataNodeBuilder object has no DataNodes, or existing DataNodes have no connected graph. - + Notes: ------ - This method makes use of internal logging for debugging and timing. 
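        Example:
            An illustrative sketch; it assumes the builder currently holds several root dataNodes
            of a single concept type (e.g. one per sentence in a batch):

                builder.createBatchRootDN()       # wraps the existing roots under one 'batch' dataNode
                batch_dn = builder.getDataNode()  # the batch root is now the single root returned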
@@ -3557,36 +3559,36 @@ def createBatchRootDN(self): if not getProductionModeStatus(): _DataNodeBuilder__Logger.info(f'No new Batch Root DataNode created - DataNode Builder already has single Root DataNode with id {rootDn.instanceID} of type {rootDn.getOntologyNode().name}') return - + # Check if there are more than one type of DataNodes in the builder typesInDNs = set() for i, d in enumerate(existingDns): typesInDNs.add(d.getOntologyNode().name) - + # If there are more than one type of DataNodes in the builder, then it is not possible to create new Batch Root DataNode if len(typesInDNs) > 1: _DataNodeBuilder__Logger.warn('DataNode Builder has DataNodes of different types: %s, not possible to create batch Datanode' % (typesInDNs)) return - + # Create the Batch Root DataNode supGraph = existingDns[1].getOntologyNode().sup if supGraph is None: - raise ValueError('Not able to create Batch Root DataNode - existing DataNodes in the Builder have concept type %s not connected to any graph: %s'%(typesInDNs)) + raise ValueError('Not able to create Batch Root DataNode - existing DataNodes in the Builder have concept type %s not connected to any graph'%(typesInDNs)) batchRootDNValue = "" batchRootDNID = 0 - + if 'batch' in supGraph.concepts: batchRootDNOntologyNode = supGraph.concepts['batch' ] else: batchRootDNOntologyNode = Concept(name='batch') supGraph.attach(batchRootDNOntologyNode) - + batchRootDN = DataNode(myBuilder = self, instanceID = batchRootDNID, instanceValue = batchRootDNValue, ontologyNode = batchRootDNOntologyNode) - + for i, d in enumerate(existingDns): - batchRootDN.addChildDataNode(d) - + batchRootDN.addChildDataNode(d) + # The new Root DataNode is the batch Root DataNode self.__updateRootDataNodeList([batchRootDN]) @@ -3595,39 +3597,39 @@ def createBatchRootDN(self): self.myLoggerTime.info('Created single Batch Root DataNode with id %s of type %s'%(batchRootDNID,batchRootDNOntologyNode)) else: raise ValueError('DataNode Builder has no DataNode started yet') - + def getDataNode(self, context="interference", device='auto'): """ Retrieves and returns the first DataNode from the DataNodeBuilder object based on the given context and device. - + Parameters: ----------- context : str, optional The context under which to get the DataNode, defaults to "interference". device : str, optional The torch device to set for the DataNode, defaults to 'auto'. - + Returns: -------- DataNode or None Returns the first DataNode if it exists, otherwise returns None. - + Side Effects: ------------- - Updates the torch device for the returned DataNode based on the 'device' parameter. - Logs various messages based on the context and production mode. - + Raises: ------- None - + Notes: ------ - This method makes use of internal logging for debugging and timing.
- + """ self.__addGetDataNodeCounter() - + if context=="interference": if self.skeletonDataNode: self.myLoggerTime.info("DataNode Builder is using skeleton datanode mode") @@ -3637,7 +3639,7 @@ def getDataNode(self, context="interference", device='auto'): # self['DataNodeTime'] is in nanoseconds, so divide by 1000000 to get milliseconds elapsedInMsDataNodeBuilder = sum(self['DataNodeTime'])/1000000 self.myLoggerTime.info(f"DataNode Builder time usage - {elapsedInMsDataNodeBuilder:.5f}ms") - + #self.myLoggerTime.info(f"DataNode Builder elapsed time in ns - {self['DataNodeTime']}") #self.myLoggerTime.info(f"DataNode Builder start time in ns - {self['DataNodeTime_start']}") #self.myLoggerTime.info(f"DataNode Builder end time in ns - {self['DataNodeTime_end']}") @@ -3645,17 +3647,17 @@ def getDataNode(self, context="interference", device='auto'): # If DataNode it created then return it if dict.__contains__(self, 'dataNode'): existingDns = dict.__getitem__(self, 'dataNode') - + if len(existingDns) != 0: returnDn = existingDns[0] - + # Set the torch device returnDn.current_device = device if returnDn.current_device == 'auto': # if not set use cpu or cuda if available returnDn.current_device = 'cpu' if torch.cuda.is_available(): returnDn.current_device = 'cuda' - + if len(existingDns) != 1: typesInDNs = {d.getOntologyNode().name for d in existingDns[1:]} _DataNodeBuilder__Logger.warning(f'Returning first dataNode with id {returnDn.instanceID} of type {returnDn.getOntologyNode().name} - there are total {len(existingDns)} dataNodes of types {typesInDNs}') @@ -3664,21 +3666,21 @@ def getDataNode(self, context="interference", device='auto'): if not getProductionModeStatus(): _DataNodeBuilder__Logger.info(f'Returning dataNode with id {returnDn.instanceID} of type {returnDn.getOntologyNode().name}') self.myLoggerTime.info(f'Returning dataNode with id {returnDn.instanceID} of type {returnDn.getOntologyNode().name}') - + if self.skeletonDataNode: # Get the "variableSet" dictionary from the data node, or create a new empty dictionary if it doesn't exist variableSet = self.get("variableSet", {}) # Create a dictionary of the items in "variableSet" with the keys and values swapped variableSetDict = {k2: self[k1] for k1, k2 in dict(variableSet).items()} - + # Add the "variableSet" dictionary to the return data node attributes returnDn.attributes["variableSet"] = variableSetDict # Get the "propertySet" dictionary from the data node, or create a new empty dictionary if it doesn't exist propertySet = self.get("propertySet", {}) - # Create a dictionary of the items in "propertySet" + # Create a dictionary of the items in "propertySet" propertySetDict = {k2: self[k1] for k1, k2 in dict(propertySet).items()} # Add the "propertySet" dictionary to the return data node attributes @@ -3692,51 +3694,51 @@ def getDataNode(self, context="interference", device='auto'): if dn == returnDn: continue dn.attributes["rootDataNode"] = returnDn - + return returnDn - + _DataNodeBuilder__Logger.error('Returning None - there are no dataNode') return None - + def getBatchDataNodes(self): """ Retrieves and returns all DataNodes stored in the DataNodeBuilder object. - + Returns: -------- list or None Returns a list of all existing DataNodes if they exist; otherwise returns None. - + Side Effects: ------------- - Logs various messages about the internal state and time usage of the DataNodeBuilder object. - + Raises: ------- None - + Notes: ------ - This method makes use of internal logging for debugging and timing. 
""" self.__addGetDataNodeCounter() - + if 'Counter' + '_setitem' in self: self.myLoggerTime.info("DataNode Builder the set method called - %i times"%(self['Counter' + '_setitem' ])) if 'DataNodeTime' in self: # self['DataNodeTime'] is in nanoseconds, so divide by 1000000 to get milliseconds elapsedInMsDataNodeBuilder = sum(self['DataNodeTime'])/1000000 self.myLoggerTime.info(f"DataNode Builder time usage - {elapsedInMsDataNodeBuilder:.5f}ms") - + if dict.__contains__(self, 'dataNode'): existingDns = dict.__getitem__(self, 'dataNode') - - if len(existingDns) > 0: - + + if len(existingDns) > 0: + if not getProductionModeStatus(): _DataNodeBuilder__Logger.info('Returning %i dataNodes - %s'%(len(existingDns),existingDns)) return existingDns - + _DataNodeBuilder__Logger.error('Returning None - there are no dataNodes') return None diff --git a/domiknows/graph/logicalConstrain.py b/domiknows/graph/logicalConstrain.py index f3d933d7..346195f7 100644 --- a/domiknows/graph/logicalConstrain.py +++ b/domiknows/graph/logicalConstrain.py @@ -376,33 +376,40 @@ def createILPCount(self, model, myIlpBooleanProcessor, v, headConstrain, cOperat lcVariableSet0 = v[lcVariableName0] zVars = [] # Output ILP variables - - for i, _ in enumerate(lcVariableSet0): - varsSetup = [] + # for i, _ in enumerate(lcVariableSet0): + # varsSetup = [] + # + # var = [] + # for currentV in iter(v): + # var.extend(v[currentV][i]) + # + # if len(var) == 0: + # if not (headConstrain or integrate): + # zVars.append([None]) + # + # continue + # + # if headConstrain or integrate: + # varsSetup.extend(var) + # else: + # varsSetup.append(var) + varsSetup = [] - var = [] - for currentV in iter(v): - var.extend(v[currentV][i]) - - if len(var) == 0: - if not (headConstrain or integrate): - zVars.append([None]) - - continue - - if headConstrain or integrate: - varsSetup.extend(var) - else: - varsSetup.append(var) - - # -- Use ILP variable setup to create constrains - if headConstrain or integrate: - zVars.append([myIlpBooleanProcessor.countVar(model, *varsSetup, onlyConstrains = headConstrain, limitOp = cOperation, limit=cLimit, - logicMethodName = logicMethodName)]) - else: - for current_var in varsSetup: - zVars.append([myIlpBooleanProcessor.countVar(model, *current_var, onlyConstrains = headConstrain, limitOp = cOperation, limit=cLimit, - logicMethodName = logicMethodName)]) + var = [currentV[0] for currentV in iter(lcVariableSet0)] + + if headConstrain or integrate: + varsSetup.extend(var) + else: + varsSetup.append(var) + + # -- Use ILP variable setup to create constrains + if headConstrain or integrate: + zVars.append([myIlpBooleanProcessor.countVar(model, *varsSetup, onlyConstrains = headConstrain, limitOp = cOperation, limit=cLimit, + logicMethodName = logicMethodName)]) + else: + for current_var in varsSetup: + zVars.append([myIlpBooleanProcessor.countVar(model, *current_var, onlyConstrains = headConstrain, limitOp = cOperation, limit=cLimit, + logicMethodName = logicMethodName)]) if model is not None: model.update() diff --git a/domiknows/program/model/lossModel.py b/domiknows/program/model/lossModel.py index f1a30d75..09cfca25 100644 --- a/domiknows/program/model/lossModel.py +++ b/domiknows/program/model/lossModel.py @@ -12,7 +12,8 @@ class LossModel(torch.nn.Module): logger = logging.getLogger(__name__) def __init__(self, graph, - tnorm='P', + tnorm='P', + counting_tnorm=None, sample = False, sampleSize = 0, sampleGlobalLoss = False, device='auto'): """ This function initializes a LossModel object with the given parameters and 
sets up the @@ -42,6 +43,7 @@ def __init__(self, graph, self.build = True self.tnorm = tnorm + self.counting_tnorm = counting_tnorm self.device = device self.sample = sample @@ -102,7 +104,7 @@ def forward(self, builder, build=None): datanode = builder.getDataNode(device=self.device) # Call the loss calculation returns a dictionary, keys are matching the constraints - constr_loss = datanode.calculateLcLoss(tnorm=self.tnorm, sample=self.sample, sampleSize = self.sampleSize) + constr_loss = datanode.calculateLcLoss(tnorm=self.tnorm,counting_tnorm=self.counting_tnorm, sample=self.sample, sampleSize = self.sampleSize) lmbd_loss = [] if self.sampleGlobalLoss and constr_loss['globalLoss']: @@ -131,7 +133,7 @@ def forward(self, builder, build=None): class PrimalDualModel(LossModel): logger = logging.getLogger(__name__) - def __init__(self, graph, tnorm='P', device='auto'): + def __init__(self, graph, tnorm='P',counting_tnorm=None, device='auto'): """ The above function is the constructor for a class that initializes an object with a graph, tnorm, and device parameters. @@ -147,7 +149,7 @@ def __init__(self, graph, tnorm='P', device='auto'): :param device: The `device` parameter specifies the device on which the computations will be performed. It can take the following values:, defaults to auto (optional) """ - super().__init__(graph, tnorm=tnorm, device=device) + super().__init__(graph, tnorm=tnorm, counting_tnorm = counting_tnorm, device=device) class SampleLossModel(torch.nn.Module): logger = logging.getLogger(__name__) @@ -277,7 +279,6 @@ def forward(self, builder, build=None): if not replace_mul: loss_value = true_val.sum() / lossTensor.sum() loss_value = epsilon - ( -1 * torch.log(loss_value) ) - # loss_value = -1 * torch.log(loss_value) if self.iter_step < self.warmpup: with torch.no_grad(): min_val = loss_value diff --git a/domiknows/solver/gurobiILPOntSolver.py b/domiknows/solver/gurobiILPOntSolver.py index e9566b76..68d6fc37 100644 --- a/domiknows/solver/gurobiILPOntSolver.py +++ b/domiknows/solver/gurobiILPOntSolver.py @@ -1762,7 +1762,7 @@ def generateSemanticSample(self, rootDn, conceptsRelations): return productSize # -- Calculated loss values for logical constraints - def calculateLcLoss(self, dn, tnorm='L', sample = False, sampleSize = 0, sampleGlobalLoss = False, conceptsRelations = None): + def calculateLcLoss(self, dn, tnorm='L',counting_tnorm=None, sample = False, sampleSize = 0, sampleGlobalLoss = False, conceptsRelations = None): start = perf_counter() m = None @@ -1785,6 +1785,8 @@ def calculateLcLoss(self, dn, tnorm='L', sample = False, sampleSize = 0, sampleG else: myBooleanMethods = self.myLcLossBooleanMethods self.myLcLossBooleanMethods.setTNorm(tnorm) + if counting_tnorm: + self.myLcLossBooleanMethods.setCountingTNorm(counting_tnorm) self.myLogger.info('Calculating loss ') self.myLoggerTime.info('Calculating loss ') diff --git a/domiknows/solver/lcLossBooleanMethods.py b/domiknows/solver/lcLossBooleanMethods.py index 230346be..3b93e3a3 100644 --- a/domiknows/solver/lcLossBooleanMethods.py +++ b/domiknows/solver/lcLossBooleanMethods.py @@ -11,6 +11,7 @@ class lcLossBooleanMethods(ilpBooleanProcessor): def __init__(self, _ildConfig = ilpConfig) -> None: super().__init__() self.tnorm = 'P' + self.counting_tnorm = None self.grad = True self.myLogger = logging.getLogger(ilpConfig['log_name']) @@ -27,6 +28,22 @@ def setTNorm(self, tnorm='L'): raise Exception('Unknown type of t-norms formulation - %s'%(tnorm)) self.tnorm = tnorm + + def setCountingTNorm(self, tnorm='L'): + if 
tnorm =='L': + if self.ifLog: self.myLogger.info("Using Lukasiewicz t-norms Formulation") + elif tnorm =='G': + if self.ifLog: self.myLogger.info("Using Godel t-norms Formulation") + elif tnorm =='P': + if self.ifLog: self.myLogger.info("Using Product t-norms Formulation") + elif tnorm =='SP': + if self.ifLog: self.myLogger.info("Using Simplified Product t-norms Formulation") + #elif tnorm =='LSE': + # if self.ifLog: self.myLogger.info("Using Log Sum Exp Formulation") + else: + raise Exception('Unknown type of t-norms formulation - %s'%(tnorm)) + + self.counting_tnorm = tnorm @@ -366,10 +383,54 @@ def epqVar(self, _, var1, var2, onlyConstrains = False): return epqLoss else: return epqSuccess - + + def calc_probabilities(self, t, s): + + n = len(t) + dp = torch.zeros(s + 1, device=self.current_device, dtype=torch.float64) + dp[0] = 1.0 + dp.requires_grad_() + for i in range(n): + dp_new = dp.clone() + dp_new[1:min(s, i + 1) + 1] = dp[1:min(s, i + 1) + 1] * (1 - t[i]) + dp[:min(s, i + 1)] * t[i] + dp_new[0] = dp[0] * (1 - t[i]) + dp = dp_new + return dp + def countVar(self, _, *var, onlyConstrains = False, limitOp = '==', limit = 1, logicMethodName = "COUNT"): logicMethodName = "COUNT" - + + method=self.counting_tnorm if self.counting_tnorm else self.tnorm + #if method=="LSE": # log sum exp + # exists_at_least_one = lambda t, beta=100.0: torch.clamp(-torch.log((1 / beta) * torch.log(torch.sum(torch.exp(beta * t)))), min=0,max=1) + # exists_at_least_s = lambda t, s, beta=10.0: torch.clamp(torch.relu(s - torch.sum(torch.sigmoid(beta * (t - 0.5)))),max=1) + # exists_at_most_s = lambda t, s, beta=10.0: torch.clamp(torch.relu(torch.sum(torch.sigmoid(beta * (t - 0.5))) - s),max=1) + # exists_exactly_s = lambda t, s, beta=10.0: torch.clamp(torch.abs(s - torch.sum(torch.sigmoid(beta * (t - 0.5)))),max=1) + if method=="G": # Godel logic + exists_at_least_one = lambda t: 1 - torch.max(t) + exists_at_least_s = lambda t, s: 1- torch.min(torch.sort(t, descending=True)[0][:s]) + exists_at_most_s = lambda t, s: 1 - torch.min(torch.sort(1 - t, descending=True)[0][:len(t)-s]) + exists_exactly_s = lambda t, s: 1 - torch.min(torch.min(torch.sort(t, descending=True)[0][:s]) , torch.min(torch.sort(1 - t, descending=True)[0][:len(t)-s])) + elif method == "L": # Łukasiewicz logic + exists_at_least_one = lambda t: 1 - torch.min(torch.sum(t), torch.ones(1, device=self.current_device, requires_grad=True, dtype=torch.float64)) + exists_at_least_s = lambda t, s: 1 - torch.max(torch.sum(torch.sort(t, descending=True)[0][:s])-(s-1), torch.zeros(1, device=self.current_device, requires_grad=True, dtype=torch.float64)) + exists_at_most_s = lambda t, s: 1 - torch.max(torch.sum(torch.sort(1 - t, descending=True)[0][:len(t)-s])-(len(t)-s-1), torch.zeros(1, device=self.current_device, requires_grad=True, dtype=torch.float64)) + exists_exactly_s = lambda t, s: 1 - torch.max(torch.sum(torch.sort(t, descending=True)[0][:s])-(s-1)+torch.sum(torch.sort(1 - t, descending=True)[0][:len(t)-s])-(len(t)-s-1), torch.zeros(1, device=self.current_device, requires_grad=True, dtype=torch.float64)) + + elif method == "P": # Product logic + + exists_at_least_one = lambda t: torch.prod(1 - t) + exists_at_least_s = lambda t, s: 1 - torch.sum(self.calc_probabilities(t, len(t))[s:]) + exists_at_most_s = lambda t, s: 1 - torch.sum(self.calc_probabilities(t, s)) + exists_exactly_s = lambda t, s: 1 - self.calc_probabilities(t, s)[s] + + else: # "SP" # Simplified product logic + exists_at_least_one = lambda t:
torch.prod(1 - t) + exists_at_least_s = lambda t, s: 1 - torch.prod(torch.sort(t, descending=True)[0][:s]) + exists_at_most_s = lambda t, s: 1 - torch.prod(torch.sort(1 - t, descending=True)[0][:len(t) - s]) + exists_exactly_s = lambda t, s: 1 - self.calc_probabilities(t, s)[s] + + if self.ifLog: self.myLogger.debug("%s called with: %s"%(logicMethodName,var)) var = self._fixVar(var) @@ -382,8 +443,8 @@ def countVar(self, _, *var, onlyConstrains = False, limitOp = '==', limit = 1, l tOne = torch.ones(1, device=self.current_device, requires_grad=True, dtype=torch.float64) if limitOp == '==': # == limit - countLoss = torch.minimum(torch.maximum(torch.abs(torch.sub(limit, varSum)), tZero), tOne) # min(max(abs(varSum - limit), 0), 1) - + #countLoss = torch.minimum(torch.maximum(torch.abs(torch.sub(limit, varSum)), tZero), tOne) # min(max(abs(varSum - limit), 0), 1) + countLoss=exists_exactly_s(varSum, limit) if onlyConstrains: return countLoss else: @@ -391,9 +452,17 @@ def countVar(self, _, *var, onlyConstrains = False, limitOp = '==', limit = 1, l return countSuccess else: if limitOp == '>=': # > limit + #Existsl + if onlyConstrains: + if limit ==1:return exists_at_least_one(varSum) + else: return exists_at_least_s(varSum, limit) countSuccess = torch.minimum(torch.maximum(torch.sub(varSum, limit), tZero), tOne) # min(max(varSum - limit, 0), 1) elif limitOp == '<=': # < limit + #atmostL + + if onlyConstrains: + return exists_at_most_s(varSum, limit) countSuccess = torch.minimum(torch.maximum(torch.sub(limit, varSum), tZero), tOne) # min(max(limit - varSum, 0), 1) if onlyConstrains: diff --git a/domiknows/solver/lcLossSampleBooleanMethods.py b/domiknows/solver/lcLossSampleBooleanMethods.py index 2a99a1be..1fdac0d0 100644 --- a/domiknows/solver/lcLossSampleBooleanMethods.py +++ b/domiknows/solver/lcLossSampleBooleanMethods.py @@ -190,16 +190,16 @@ def countVar(self, _, *var, onlyConstrains = False, limitOp = '==', limit = 1, l elif limitOp == '==': fixedVar.append(torch.zeros([self.sampleSize], device=self.current_device)) # -- - + # V = 100, limitTensor = torch.full([self.sampleSize], limit, device = self.current_device) # Calculate sum varSum = torch.zeros([self.sampleSize], device=self.current_device) if fixedVar: - varSum = fixedVar[0].int() - + varSum = fixedVar[0].clone() + for i in range(1, len(fixedVar)): - varSum.add_(fixedVar[i].int()) + varSum.add_(fixedVar[i]) # Check condition if limitOp == '>=': diff --git a/test_regr/examples/PMDExistL/graph.py b/test_regr/examples/PMDExistL/graph.py new file mode 100644 index 00000000..a456f690 --- /dev/null +++ b/test_regr/examples/PMDExistL/graph.py @@ -0,0 +1,31 @@ +def get_graph(args): + from domiknows.graph import Graph, Concept, Relation + from domiknows.graph.logicalConstrain import orL, existsL, ifL, notL, andL, atMostL, atLeastL, exactL + from domiknows.graph import EnumConcept + Graph.clear() + Concept.clear() + Relation.clear() + + with Graph('global_PMD') as graph: + a = Concept(name='a') + b = Concept(name='b') + a_contain_b, = a.contains(b) + + b_answer = b(name="answer_b", ConceptClass=EnumConcept, values=["zero", "one"]) + + expected_zero = b_answer.__getattr__("zero") + expected_one = b_answer.__getattr__("one") + + expected_value = expected_zero if args.expected_value == 0 else expected_one + + if args.atLeastL and args.atMostL: + atLeastL(expected_value, args.expected_atLeastL) + atMostL(expected_value, args.expected_atMostL) + elif args.atMostL: + atMostL(expected_value, args.expected_atMostL) + elif args.atLeastL: + 
atLeastL(expected_value, args.expected_atLeastL) + else: + exactL(expected_value, args.expected_atLeastL) + + return graph, a, b, a_contain_b, b_answer diff --git a/test_regr/examples/PMDExistL/main.py b/test_regr/examples/PMDExistL/main.py new file mode 100644 index 00000000..a35ad6cd --- /dev/null +++ b/test_regr/examples/PMDExistL/main.py @@ -0,0 +1,97 @@ +import sys +import argparse +from typing import Any +import numpy as np +import torch +from domiknows.sensor.pytorch.sensors import ReaderSensor +from domiknows.sensor.pytorch.relation_sensors import EdgeSensor, FunctionalSensor +from domiknows.sensor.pytorch.learners import ModuleLearner +from domiknows.sensor import Sensor +from domiknows.program.metric import MacroAverageTracker +from domiknows.program.loss import NBCrossEntropyLoss +from domiknows.program.lossprogram import PrimalDualProgram, SampleLossProgram +from domiknows.program.model.pytorch import SolverModel + +from utils import TestTrainLearner, return_contain, create_dataset, evaluate_model, train_model + +sys.path.append('../../../../domiknows/') +from graph import get_graph + +Sensor.clear() + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Machine Learning Experiment") + parser.add_argument("--counting_tnorm", choices=["G", "P", "L", "SP"], default="SP", help="The tnorm method to use for the counting constraints") + parser.add_argument("--atLeastL", default=False, type=bool, help="Use at least L constraint") + parser.add_argument("--atMostL", default=False, type=bool, help="Use at most L constraint") + parser.add_argument("--epoch", default=500, type=int, help="Number of training epochs") + parser.add_argument("--expected_atLeastL", default=3, type=int, help="Expected value for at least L") + parser.add_argument("--expected_atMostL", default=3, type=int, help="Expected value for at most L") + parser.add_argument("--expected_value", default=0, type=int, help="Expected value") + parser.add_argument("--N", default=10, type=int, help="N parameter") + parser.add_argument("--M", default=8, type=int, help="M parameter") + parser.add_argument("--model", default="sampling", type=str, help="Model Types [Sampling/PMD]") + parser.add_argument("--sample_size", default=-1, type=int, help="Sample size for sampling program") + return parser.parse_args() + + +def setup_graph(args: argparse.Namespace, a: Any, b: Any, a_contain_b: Any, b_answer: Any) -> None: + a["index"] = ReaderSensor(keyword="a") + b["index"] = ReaderSensor(keyword="b") + b["temp_answer"] = ReaderSensor(keyword="label") + b[a_contain_b] = EdgeSensor(b["index"], a["index"], relation=a_contain_b, forward=return_contain) + b[b_answer] = ModuleLearner(a_contain_b, "index", module=TestTrainLearner(args.N), device="cpu") + b[b_answer] = FunctionalSensor(a_contain_b, "temp_answer", forward=lambda _, label: label, label=True) + + +def main(args: argparse.Namespace): + np.random.seed(0) + torch.manual_seed(0) + + graph, a, b, a_contain_b, b_answer = get_graph(args) + dataset = create_dataset(args.N, args.M) + setup_graph(args, a, b, a_contain_b, b_answer) + if args.model == "sampling": + # print("sampling") + program = SampleLossProgram( + graph, SolverModel, poi=[a, b, b_answer], + inferTypes=['local/argmax'], + loss=MacroAverageTracker(NBCrossEntropyLoss()), + sample=True, + sampleSize=args.sample_size, + sampleGlobalLoss=False, + beta=1, device='cpu', tnorm="L", counting_tnorm=args.counting_tnorm + ) + else: + program = PrimalDualProgram( + graph, SolverModel, poi=[a, b, 
b_answer], + inferTypes=['local/argmax'], + loss=MacroAverageTracker(NBCrossEntropyLoss()), + beta=10, device='cpu', tnorm="L", counting_tnorm=args.counting_tnorm) + + expected_value = args.expected_value + train_model(program, dataset, num_epochs=2) + + before_count = evaluate_model(program, dataset, b_answer).get(expected_value, 0) + train_model(program, dataset, args.epoch, constr_loss_only=True) + + pass_test_case = True + actual_count = evaluate_model(program, dataset, b_answer).get(expected_value, 0) + + if args.atLeastL: + pass_test_case &= (actual_count >= args.expected_atLeastL) + if args.atMostL: + pass_test_case &= (actual_count <= args.expected_atMostL) + if not args.atLeastL and not args.atMostL: + pass_test_case &= (actual_count == args.expected_atLeastL) + + print(f"Test case {'PASSED' if pass_test_case else 'FAILED'}") + print( + f"expected_value, before_count, actual_count,pass_test_case): {expected_value, before_count, actual_count, pass_test_case}") + return pass_test_case, before_count, actual_count + + +if __name__ == "__main__": + args = parse_arguments() + main(args) diff --git a/test_regr/examples/PMDExistL/testcase.py b/test_regr/examples/PMDExistL/testcase.py new file mode 100644 index 00000000..359a4c97 --- /dev/null +++ b/test_regr/examples/PMDExistL/testcase.py @@ -0,0 +1,98 @@ +import subprocess +import sys +import os +import itertools +import json +from concurrent.futures import ProcessPoolExecutor, as_completed +from collections import defaultdict + + +def run_test(params): + # Convert params to command-line arguments + args = [] + for key, value in params.items(): + if not str(value) == "False": + args.extend([f'--{key}', str(value)]) + + # Get the path to the Python interpreter in the current virtual environment + python_executable = sys.executable + + # Construct the command to run the main script + cmd = [python_executable, 'main.py'] + args + + # Run the command in a subprocess + try: + # Use UTF-8 encoding and replace any characters that can't be decoded + result = subprocess.run(cmd, capture_output=True, text=True, encoding='utf-8', errors='replace', check=True) + return params, True, result.stdout + except subprocess.CalledProcessError as e: + return params, False, e.stderr + + +def run_tests(param_combinations): + """Run tests with different combinations of input arguments.""" + # Generate all combinations of parameters + keys, values = zip(*param_combinations.items()) + print(keys, values) + combinations = [dict(zip(keys, v)) for v in itertools.product(*values)] + + # Run tests for each combination using ProcessPoolExecutor with max_workers=4 + results = defaultdict(list) + total_combinations = len(combinations) + + with ProcessPoolExecutor(max_workers=4) as executor: + for i, (params, test_passed, output) in enumerate(executor.map(run_test, combinations), 1): + counting_tnorm = params['counting_tnorm'] + results[counting_tnorm].append((params, 'PASSED' in output, output)) + print(f"\nCompleted test {i}/{total_combinations}:") + print(f"Parameters: {params}") + print(output.split("\n")[-2]) + print(f"Passed: {'PASSED' in output}") + + # Print summary of results per counting_tnorm + print("\n--- Test Summary ---") + for counting_tnorm, tnorm_results in results.items(): + passed_tests = sum(1 for _, passed, _ in tnorm_results if passed) + print(f"\nResults for counting_tnorm = {counting_tnorm}:") + print(f"Passed {passed_tests} out of {len(tnorm_results)} tests.") + + # Print details of failed tests for this counting_tnorm + if passed_tests < 
len(tnorm_results): + print(f"Failed tests for counting_tnorm = {counting_tnorm}:") + for params, passed, output in tnorm_results: + if not passed: + print(f"Parameters: {params}") + print(f"Error output: {output}") + + +if __name__ == "__main__": + # Define the parameter combinations to test PMD + PMD_combinations = { + 'counting_tnorm': ["G", "P", "SP", "L"], + 'atLeastL': [True, False], + 'atMostL': [True, False], + 'epoch': [1000], + 'expected_atLeastL': [2], + 'expected_atMostL': [5], + 'expected_value': [0, 1], + 'N': [10], + 'M': [8], + 'model': ["PMD"], + } + # run_tests(PMD_combinations) + + # Define the parameter combinations to test sampling model + sampling_combinations = { + 'counting_tnorm': ["G"], + 'atLeastL': [True, False], + 'atMostL': [True, False], + 'epoch': [1000], + 'expected_atLeastL': [1, 2, 3], + 'expected_atMostL': [3, 4, 5], + 'expected_value': [0, 1], + 'N': [10], + 'M': [8], + 'model': ["sampling"], + "sample_size": [10, 20, 50, 100, 200, -1] # maximum 2^8 = 256 + } + run_tests(sampling_combinations) diff --git a/test_regr/examples/PMDExistL/utils.py b/test_regr/examples/PMDExistL/utils.py new file mode 100644 index 00000000..6b321631 --- /dev/null +++ b/test_regr/examples/PMDExistL/utils.py @@ -0,0 +1,84 @@ +from typing import List, Dict, Any +import numpy as np +import torch +from torch import nn +from tqdm import tqdm +from domiknows.sensor.pytorch.learners import TorchLearner +from domiknows.program.model.base import Mode +from domiknows.program.lossprogram import PrimalDualProgram + + +class DummyLearner(TorchLearner): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.stack((torch.ones(len(x)) * 4, torch.ones(len(x)) * 6), dim=-1) + + +class TestTrainLearner(nn.Module): + def __init__(self, input_size: int): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(input_size, input_size), + nn.Linear(input_size, 2) + ) + + def forward(self, _, x: torch.Tensor) -> torch.Tensor: + return self.layers(x) + + +def return_contain(b: torch.Tensor, _: Any) -> torch.Tensor: + return torch.ones(len(b)).unsqueeze(-1) + + +def create_dataset(N: int, M: int) -> List[Dict[str, Any]]: + return [{ + "a": [0], + "b": [((np.random.rand(N) - np.random.rand(N))).tolist() for _ in range(M)], + "label": [1] * M + }] + + +def train_model(program: PrimalDualProgram, dataset: List[Dict[str, Any]], + num_epochs: int, constr_loss_only: bool = False) -> None: + program.model.train() + program.model.reset() + program.cmodel.train() + program.cmodel.reset() + program.model.mode(Mode.TRAIN) + + opt = torch.optim.Adam(program.model.parameters(), lr=1e-2) + copt = torch.optim.Adam(program.cmodel.parameters(), lr=1e-3) + + for _ in tqdm(range(num_epochs), desc="Training with PMD"): + for data in dataset: + opt.zero_grad() + copt.zero_grad() + mloss, _, *output = program.model(data) + closs, *_ = program.cmodel(output[1]) + + if constr_loss_only: + loss = mloss * 0 + (closs if torch.is_tensor(closs) else 0) + else: + loss = mloss + + if loss.item() < 0: + print("Negative loss", loss.item()) + break + if loss: + loss.backward() + opt.step() + copt.step() + + +def evaluate_model(program: PrimalDualProgram, dataset: List[Dict[str, Any]], b_answer: Any) -> Dict[int, int]: + program.model.eval() + program.model.reset() + program.cmodel.eval() + program.cmodel.reset() + program.model.mode(Mode.TEST) + + final_result_count = {} + for datanode in program.populate(dataset=dataset): + for child in datanode.getChildDataNodes(): + pred = child.getAttribute(b_answer, 
'local/argmax').argmax().item() + final_result_count[pred] = final_result_count.get(pred, 0) + 1 + return final_result_count \ No newline at end of file
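To make the new counting losses concrete, here is a small standalone sketch (not part of the diff) that mirrors the Product-logic ('P') dynamic program behind calc_probabilities and countVar; the helper name count_distribution and the example probabilities are assumptions for illustration only.

import torch

def count_distribution(t):
    # dp[k] = probability that exactly k of the independent events with probabilities t hold
    dp = torch.zeros(len(t) + 1, dtype=torch.float64)
    dp[0] = 1.0
    for p in t:
        new = torch.zeros_like(dp)
        new[0] = dp[0] * (1 - p)
        new[1:] = dp[1:] * (1 - p) + dp[:-1] * p
        dp = new
    return dp

t = torch.tensor([0.9, 0.8, 0.1], dtype=torch.float64)
dp = count_distribution(t)
print(1 - dp[2])            # exactL-style loss: 1 - P(exactly 2)  = 1 - 0.674 = 0.326
print(1 - dp[2:].sum())     # atLeastL-style loss: 1 - P(at least 2) = 1 - 0.746 = 0.254
print(1 - dp[:3].sum())     # atMostL-style loss: 1 - P(at most 2)  = 1 - 0.928 = 0.072

These quantities correspond to exists_exactly_s, exists_at_least_s, and exists_at_most_s in the 'P' branch of countVar, where the tensor collects the per-variable probabilities of the constrained concept.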