-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
70 lines (59 loc) · 2.45 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
from data_collection.data_collector import Data_collection
from predictions.live_predict import LivePredict
from src.train_model import ModelTraining
from database.database_operations import CassandraCRUD
import sys
def data_collection():
    """Capture data for one class using ``Data_collection``.

    Expected command line::

        main.py data_collection <class_name> <collection_type>

    ``collection_type`` must be ``"pose"`` or ``"face"``. It selects whether
    ``pose_collection()`` or ``face_collection()`` is invoked. A usage message
    is printed (and nothing is collected) when an argument is missing or the
    collection type is not recognised.
    """
    usage = ('main.py has arguments:\n'
             'class name: name of class to capture data\n'
             'collection type: "pose" or "face"')
    # Read argv inside the try: previously the lookups happened before the
    # try block, so a missing argument raised an uncaught IndexError.
    try:
        class_name = sys.argv[2]
        collection_type = sys.argv[3]
    except IndexError:
        print(usage)
        return

    if collection_type not in ("pose", "face"):
        # Previously an unknown type built the collector and then did nothing.
        print(f"Unknown collection type {collection_type!r}")
        print(usage)
        return

    collection = Data_collection(class_name=class_name, collection_type=collection_type)
    if collection_type == "pose":
        collection.pose_collection()
    else:
        collection.face_collection()
def test_prediction():
    """Run a live prediction mode selected on the command line.

    Expected command line::

        main.py test_prediction pose
        main.py test_prediction face <arg>
        main.py test_prediction db

    Modes map to ``LivePredict().show_pose()``,
    ``LivePredict().show_face(<arg>)`` and ``LivePredict().show_both()``.
    A help message is printed when the mode (or the face argument) is
    missing or the mode is not recognised.
    """
    help_text = "Please Select from\n1.Pose\n2.Face\n3.Db"
    # Validate arguments before constructing LivePredict, and keep the try
    # narrow so IndexErrors raised inside the prediction code are not
    # mistaken for missing command-line arguments.
    try:
        mode = sys.argv[2]
        face_arg = sys.argv[3] if mode == "face" else None
    except IndexError as e:
        print(help_text)
        print(e)
        return

    if mode not in ("pose", "face", "db"):
        # Previously an unknown mode silently did nothing.
        print(help_text)
        return

    prediction = LivePredict()
    if mode == "pose":
        prediction.show_pose()
    elif mode == "face":
        prediction.show_face(face_arg)
    else:
        prediction.show_both()
def db_crud():
    """Run a database sub-command given as ``sys.argv[2]``.

    Expected command line::

        main.py db_crud show

    ``show`` is the only supported sub-command. It calls
    ``CassandraCRUD("test_key").show_both()``. A usage message is printed for
    a missing or unknown sub-command.
    """
    # Previously a missing argument raised IndexError that surfaced at the
    # top level as the misleading "please enter the operation name", and an
    # unknown sub-command was silently ignored.
    if len(sys.argv) < 3 or sys.argv[2] != "show":
        print("db_crud has arguments:\noperation: show")
        return
    crud = CassandraCRUD("test_key")
    crud.show_both()
def train():
    """Train the model selected by ``sys.argv[2]`` (``"face"`` or ``"pose"``).

    Expected command line::

        main.py train face
        main.py train pose

    ``face`` calls ``ModelTraining().train_face_model(...)``.
    ``pose`` calls ``ModelTraining().train_model(...)`` with ``n_components=15``.
    All paths are hard-coded below. A usage message is printed for a missing
    or unknown model type.
    """
    usage = "please give the model to train: train face | train pose"
    # Only the argv lookup belongs in the try. The old message referred to
    # n_components/epochs, which are not read from the command line.
    try:
        model_type = sys.argv[2]
    except IndexError:
        print(usage)
        return

    if model_type not in ("face", "pose"):
        print(f"Unknown model type {model_type!r}")
        print(usage)
        return

    training = ModelTraining()
    if model_type == "face":
        # NOTE(review): model_output_directory is the same file as
        # data_directory, so the output may overwrite the input embeddings
        # file -- confirm this is intended.
        # Optional keyword previously used here (kept for reference):
        #   model_input_directory='models/face_model.h5'
        training.train_face_model(data_directory="models/faces_embeddings.npz",
                                  model_output_directory="models/faces_embeddings.npz",
                                  )
    else:
        # Optional keywords previously used here (kept for reference):
        #   keras_model_input_directory='models/pose_model.h5'
        #   pca_model_input_directory="models/pca_model.joblib"
        training.train_model(data_directory="raw_data/training/landmarks.csv",
                             keras_model_output_directory="models/pose_model.h5",
                             pca_model_output_directory="models/pose_pca.joblib",
                             n_components=15)
if __name__ == '__main__':
    # Explicit allow-list of sub-commands. The original looked the name up in
    # globals(), which would also try to call imported names such as `sys` or
    # `LivePredict`.
    commands = {
        "data_collection": data_collection,
        "test_prediction": test_prediction,
        "db_crud": db_crud,
        "train": train,
    }
    usage = ("data_collection(class_name, collection_type)\n"
             "test_prediction(pose | face <arg> | db)\n"
             "db_crud(show)\n"
             "train(face | pose)")

    # Validate the operation name up front instead of wrapping the command
    # call in try/except. Otherwise a KeyError or IndexError raised *inside*
    # a command would be reported as a bad or missing operation name.
    if len(sys.argv) < 2:
        print("please enter the operation name")
        print(usage)
    elif sys.argv[1] not in commands:
        print(f"Unknown operation: {sys.argv[1]!r}")
        print(usage)
    else:
        commands[sys.argv[1]]()