Skip to content

Commit

Permalink
tree hot
Browse files Browse the repository at this point in the history
  • Loading branch information
raynardj committed Apr 30, 2022
1 parent ddb55c8 commit 8c20ecf
Show file tree
Hide file tree
Showing 9 changed files with 972 additions and 1,762 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ __pycache__
build/
dist/
*.egg-info
.hypothesis

nbs/data/*
nbs/*.db
Expand Down
2 changes: 1 addition & 1 deletion forgebox/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.0"
__version__ = "1.0.1"
37 changes: 37 additions & 0 deletions forgebox/treehot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from category import Category
import numpy as np

def cache(f):
data = dict()
def wrapper(name, parent_map):
if name in data:
return data[name]
rt = f(name, parent_map)
data[name]=rt
return rt
return wrapper

@cache
def find_ancestor_map(name, parent_map):
if name not in parent_map:
return []
else:
return [name,]+find_ancestor_map(parent_map[name], parent_map)

def tree_hot(cate, name, ancestor_map):
target = np.zeros(len(cate), dtype=int)
target[cate.c2i[ancestor_map[name]]]=1
return target

def get_depth_map(cate, ancestor_map):
cate.depth_map = dict(
(k, len(v)) for k,v in ancestor_map.items())
return cate.depth_map

def get_depth_map_array(cate, ancestor_map):
cate.depth_map_array = np.vectorize(cate.get_depth_map(ancestor_map).get)(cate.i2c)
return cate.depth_map_array

Category.tree_hot = tree_hot
Category.get_depth_map = get_depth_map
Category.get_depth_map_array = get_depth_map_array
Loading

0 comments on commit 8c20ecf

Please sign in to comment.