custom_types.py
# import open3d
import enum
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnf
from constants import DEBUG
from typing import Tuple, List, Union, Callable, Type, Iterator, Dict, Set, Optional, Any, Sized, Iterable
from types import DynamicClassAttribute
from enum import Enum, unique
import torch.optim.optimizer
import torch.utils.data
if DEBUG or True:
    seed = 99
    torch.manual_seed(seed)
    np.random.seed(seed)
# shorthand aliases for numpy arrays and their optional/sequence variants
N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VN = Optional[ARRAY]
VNS = Optional[ARRAYS]
# shorthand aliases for torch tensors and their optional/sequence variants
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]
# meshes as (vertices, faces) pairs; colors as tensors, arrays or RGB tuples
V_Mesh = Tuple[ARRAY, ARRAY]
T_Mesh = Tuple[T, Optional[T]]
T_Mesh_T = Union[T_Mesh, T]
COLORS = Union[T, ARRAY, Tuple[int, int, int]]
# device shorthands
D = torch.device
CPU = torch.device('cpu')
def get_device(device_id: int) -> D:
    # fall back to the CPU when CUDA is unavailable, warning only if a GPU was requested
    if not torch.cuda.is_available():
        if device_id >= 0:
            print("warning: GPU is not available")
        return CPU
    device_id = min(torch.cuda.device_count() - 1, device_id)
    return torch.device(f'cuda:{device_id}')
CUDA = get_device
Optimizer = torch.optim.Adam
Dataset = torch.utils.data.Dataset
DataLoader = torch.utils.data.DataLoader
Subset = torch.utils.data.Subset
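# Illustrative usage of the aliases above (not from the original project; the dataset,
# model and learning rate below are hypothetical placeholders):
#   device = CUDA(0)
#   loader = DataLoader(my_dataset, batch_size=32, shuffle=True)
#   optimizer = Optimizer(model.parameters(), lr=1e-3)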
@unique
class Padding(Enum):
    ZERO = 0
    REFLECTIVE = 1
class ModelType(enum.Enum):

    @DynamicClassAttribute
    def value(self) -> str:
        # same behavior as Enum.value; re-declared only to annotate the return type as str
        return super(ModelType, self).value

    PPE3 = 'ppe3'
    PPE2 = 'ppe2'
    EXPLICIT = 'exp'
    PE = 'pe'
    ReLU = 'relu'
    PPE = 'ppe'
    SIREN = 'siren'
    HYBRID = 'hybrid'
class LossType(enum.Enum):

    @DynamicClassAttribute
    def value(self) -> str:
        return super(LossType, self).value

    CROSS = 'cross'
    HINGE = 'hinge'
    IN_OUT = 'in_out'
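# Minimal self-test sketch, added for illustration only (not part of the original module);
# it exercises the device helper, a tensor-mesh alias and the enum values.
if __name__ == '__main__':
    device: D = CUDA(0)                      # cuda:0 when available, otherwise cpu
    verts: T = torch.rand(8, 3, device=device)
    mesh: T_Mesh = (verts, None)             # vertices with no faces
    print(device, mesh[0].shape, ModelType.SIREN.value, LossType.CROSS.value)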