-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1. 数据定义 2. 格式转换 3. 批量处理 4. 模型定义 5. 损失函数/优化器/学习率调度器设置 6. 分阶段保存和测试模型
- Loading branch information
Showing
23 changed files
with
764 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
@date: 2020/8/28 下午5:06 | ||
@file: hmdb51.py | ||
@author: zj | ||
@description: | ||
""" | ||
|
||
import torchvision.transforms as transforms | ||
|
||
from tsn.data.hmdb51 import HMDB51 | ||
|
||
|
||
def get_transform():
    """Build the preprocessing pipeline used by this HMDB51 smoke test.

    Returns:
        A ``transforms.Compose`` that converts the input to a PIL image,
        resizes it to 224x224 and converts it to a tensor.
    """
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: load the training split, then print the dataset,
    # its length and the shape/label of one sample.
    frames_root = '/home/zj/zhonglian/mmaction2/data/hmdb51/rawframes'
    annotation_root = '/home/zj/zhonglian/mmaction2/data/hmdb51'
    transform = get_transform()

    data_set = HMDB51(frames_root, annotation_root,
                      split=1, num_seg=8, train=True, transform=transform)
    print(data_set)
    print(len(data_set))

    image, target = data_set[100]
    print(image.shape)
    print(target)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
@date: 2020/8/22 下午4:20 | ||
@file: predict.py | ||
@author: zj | ||
@description: | ||
""" | ||
|
||
import cv2 | ||
import torch | ||
from tsn.data.build import build_test_transform | ||
from tsn.model.build import build_model | ||
from tsn.util.checkpoint import CheckPointer | ||
|
||
if __name__ == '__main__':
    # Single-image inference script: restore the latest checkpoint from
    # ./outputs and print the predicted class index for one image.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = build_model(num_classes=360).to(device)
    output_dir = './outputs'
    checkpointer = CheckPointer(model, save_dir=output_dir)
    checkpointer.load()
    # Inference only: switch dropout / batch-norm layers to evaluation
    # behaviour (assumes build_model returns a torch.nn.Module).
    model.eval()

    transform = build_test_transform()
    img_path = 'imgs/RotNet.png'
    # NOTE(review): the image is read as single-channel, while the test
    # transform in tsn/data/build.py normalizes with 3-channel statistics;
    # confirm that the two are compatible.
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'cannot read image: {img_path}')
    print(img.shape)
    res_img = transform(img).unsqueeze(0)  # add the batch dimension
    print(res_img.shape)

    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        outputs = model(res_img.to(device))
    _, preds = torch.max(outputs, 1)
    print(preds)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
@date: 2020/8/21 下午7:52 | ||
@file: build.py | ||
@author: zj | ||
@description: | ||
""" | ||
|
||
import os | ||
import torch | ||
|
||
from tsn.data.build import build_dataloader | ||
from tsn.model.build import build_model, build_criterion | ||
from tsn.optim.build import build_optimizer, build_lr_scheduler | ||
from tsn.engine.build import train_model | ||
from tsn.util.checkpoint import CheckPointer | ||
|
||
if __name__ == '__main__':
    # Training entry point: build data, model, loss, optimizer and
    # scheduler, then hand everything to train_model.
    epoches = 10
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    data_loaders, data_sizes = build_dataloader()

    criterion = build_criterion()
    model = build_model(num_classes=51).to(device)
    optimizer = build_optimizer(model)
    lr_scheduler = build_lr_scheduler(optimizer)

    output_dir = './outputs'
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # os.path.exists + os.mkdir and also works for nested output paths.
    os.makedirs(output_dir, exist_ok=True)
    checkpointer = CheckPointer(model, optimizer=optimizer, scheduler=lr_scheduler, save_dir=output_dir,
                                save_to_disk=True, logger=None)

    train_model('MobileNet_v2', model, criterion, optimizer, lr_scheduler, data_loaders, data_sizes, checkpointer,
                epoches=epoches, device=device)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
@date: 2020/8/5 下午4:23 | ||
@file: __init__.py | ||
@author: zj | ||
@description: | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
@date: 2020/8/21 下午7:20 | ||
@file: __init__.py | ||
@author: zj | ||
@description: | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
@date: 2020/8/21 下午7:20 | ||
@file: build.py | ||
@author: zj | ||
@description: | ||
""" | ||
|
||
from torch.utils.data import DataLoader | ||
import torchvision.transforms as transforms | ||
|
||
from .hmdb51 import HMDB51 | ||
|
||
|
||
def build_train_transform():
    """Build the augmentation pipeline applied to training frames.

    Returns:
        A ``(transform, None)`` tuple. The transform converts the input to a
        PIL image, resizes it to 224x224, applies colour jitter and a random
        horizontal flip, converts to a tensor, normalizes each of the three
        channels with mean 0.5 / std 0.5 and finally applies random erasing.
        The second element is always ``None``.
    """
    augmentations = [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.RandomErasing(),
    ]
    return transforms.Compose(augmentations), None
|
||
|
||
def build_test_transform():
    """Build the deterministic preprocessing pipeline used at test time.

    Returns:
        A ``transforms.Compose`` that converts the input to a PIL image,
        resizes it to 224x224, converts it to a tensor and normalizes each
        of the three channels with mean 0.5 / std 0.5.
    """
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(steps)
|
||
|
||
def build_dataset(data_dir='/home/zj/zhonglian/mmaction2/data/hmdb51/rawframes',
                  annotation_dir='/home/zj/zhonglian/mmaction2/data/hmdb51'):
    """Create the HMDB51 train and test datasets for split 1.

    Args:
        data_dir: Root directory holding the extracted raw frames.
        annotation_dir: Directory holding the split annotation files.

    Returns:
        A pair of dicts keyed by ``'train'`` / ``'test'``: the first maps to
        the dataset objects, the second to their lengths.
    """
    train_transform, _ = build_train_transform()
    test_transform = build_test_transform()

    train_dataset = HMDB51(data_dir, annotation_dir, num_seg=3, split=1,
                           train=True, transform=train_transform)
    # The test set must read the validation annotation file; passing
    # train=True here would evaluate the model on its own training data.
    test_dataset = HMDB51(data_dir, annotation_dir, num_seg=3, split=1,
                          train=False, transform=test_transform)

    data_sets = {'train': train_dataset, 'test': test_dataset}
    data_sizes = {name: len(data_set) for name, data_set in data_sets.items()}
    return data_sets, data_sizes
|
||
|
||
def build_dataloader(batch_size=32, num_workers=8):
    """Wrap the train and test datasets in ``DataLoader`` objects.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes for both loaders.

    Returns:
        A pair ``(loaders, sizes)``: ``loaders`` maps ``'train'`` / ``'test'``
        to the data loaders, ``sizes`` is the size dict returned by
        ``build_dataset``.
    """
    data_sets, data_sizes = build_dataset()

    train_dataloader = DataLoader(data_sets['train'], batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers)
    # Evaluation does not benefit from shuffling; keeping a fixed order makes
    # per-batch results reproducible between runs.
    test_dataloader = DataLoader(data_sets['test'], batch_size=batch_size, shuffle=False,
                                 num_workers=num_workers)

    return {'train': train_dataloader, 'test': test_dataloader}, data_sizes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
@date: 2020/8/28 下午4:37 | ||
@file: hmdb51.py | ||
@author: zj | ||
@description: | ||
""" | ||
|
||
import cv2 | ||
import random | ||
import os | ||
import numpy as np | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class HMDB51(Dataset):
    """HMDB51 raw-frame dataset that samples ``num_seg`` random frames per video.

    Every non-blank line of the annotation file must hold three
    whitespace-separated fields: frame directory (relative to ``data_dir``),
    number of frames, and integer class index.
    """

    def __init__(self, data_dir, annotation_dir, num_seg=3, split=1, train=True, transform=None):
        """Parse the annotation file for the requested split.

        Args:
            data_dir: Root directory that contains the per-video frame folders.
            annotation_dir: Directory that contains the annotation text files.
            num_seg: Number of frames sampled per video by ``__getitem__``.
            split: Split number used in the annotation file name.
            train: Read the ``train`` annotation file if true, else ``val``.
            transform: Optional callable applied to every loaded frame. It
                must return tensors, because ``__getitem__`` stacks the
                frames with ``torch.stack``.

        Raises:
            ValueError: If the annotation file does not exist.
        """
        subset = 'train' if train else 'val'
        annotation_path = os.path.join(annotation_dir, f'hmdb51_{subset}_split_{split}_rawframes.txt')

        if not os.path.isfile(annotation_path):
            raise ValueError(f'{annotation_path}不是文件路径')

        self.data_dir = data_dir
        self.transform = transform
        self.num_seg = num_seg

        video_list = []
        img_num_list = []
        cate_list = []
        with open(annotation_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                dir_name, img_num, cate = fields

                video_list.append(dir_name)
                img_num_list.append(int(img_num))
                cate_list.append(int(cate))
        self.video_list = video_list
        self.img_num_list = img_num_list
        self.cate_list = cate_list

    def _sample_frame_numbers(self, index):
        """Return ``num_seg`` distinct 1-based frame numbers of video ``index``."""
        # Frame files are numbered from 1 (img_00001.jpg) up to img_num, so
        # the sampling range has to be [1, img_num] rather than [0, img_num).
        # NOTE(review): assumes the mmaction2 raw-frame layout - confirm.
        img_num = self.img_num_list[index]
        return random.sample(range(1, img_num + 1), self.num_seg)

    def __getitem__(self, index: int):
        """Randomly pick T frames from the selected video folder.

        :return: ``(image, target)`` where ``image`` stacks the T transformed
            frames (T equals ``num_seg``; ``(T, C, H, W)`` when the transform
            yields ``(C, H, W)`` tensors) and ``target`` is the class index.
        :raises IndexError: if ``index`` is past the end of the dataset.
        :raises FileNotFoundError: if a sampled frame cannot be read.
        """
        if index >= len(self.video_list):
            # IndexError (unlike a failed assert) survives ``python -O`` and
            # lets plain ``for sample in dataset`` loops terminate.
            raise IndexError(f'index {index} is out of range')
        target = self.cate_list[index]
        video_path = os.path.join(self.data_dir, self.video_list[index])

        image_list = []
        for num in self._sample_frame_numbers(index):
            image_path = os.path.join(video_path, 'img_{:0>5d}.jpg'.format(num))
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread returns None instead of raising on failure.
                raise FileNotFoundError(f'cannot read frame: {image_path}')

            if self.transform:
                img = self.transform(img)
            image_list.append(img)
        image = torch.stack(image_list)

        return image, target

    def __len__(self) -> int:
        """Return the number of videos listed in the annotation file."""
        return len(self.video_list)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
@date: 2020/8/22 上午11:09 | ||
@file: mnist.py | ||
@author: zj | ||
@description: | ||
""" | ||
|
||
import numpy as np | ||
from PIL import Image | ||
from torchvision.datasets import FashionMNIST | ||
|
||
|
||
class FMNIST(FashionMNIST):
    """FashionMNIST variant whose ``target_transform`` receives the image.

    When a ``target_transform`` is given it is called with the raw image as
    a numpy array and must return ``(image, target)``, replacing both the
    image and the label.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Forward all arguments unchanged to ``FashionMNIST``."""
        super().__init__(root, train, transform, target_transform, download)

    def __getitem__(self, index):
        """Apply the target transform (e.g. rotation) first, then ``transform``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target). With a ``target_transform`` the target is
            whatever that transform returns (e.g. the rotation angle of the
            image); otherwise it is the integer class label.
        """
        raw = self.data[index]
        target = int(self.targets[index])

        if self.target_transform is None:
            pixels = raw.numpy()
        else:
            pixels, target = self.target_transform(raw.numpy())

        # Return a PIL image so the behaviour is consistent with the other
        # torchvision datasets.
        img = Image.fromarray(pixels, mode='L')

        if self.transform is not None:
            img = self.transform(img)

        return img, target
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: fetch one sample and print its shape and label.
    dataset = FMNIST('../../data/', train=True)
    img, target = dataset[10]
    print(np.array(img).shape)
    print(target)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
@date: 2020/8/22 上午11:14 | ||
@file: rotate.py | ||
@author: zj | ||
@description: | ||
""" | ||
|
||
import math | ||
import random | ||
import cv2 | ||
import numpy as np | ||
|
||
|
||
class Rotate:
    """Callable that rotates an image by a random whole number of degrees."""

    def __call__(self, img: np.ndarray):
        """Rotate ``img`` by an angle drawn uniformly from 0..359.

        Returns:
            tuple: (rotated image, angle in degrees).
        """
        assert isinstance(img, np.ndarray)

        degree = random.randint(0, 359)
        return rotate(img, degree), degree
|
||
|
||
def rotate(img, degree):
    """Rotate ``img`` by ``degree`` degrees on a canvas enlarged to fit it.

    The output size is the bounding box of the rotated input, and the
    rotation matrix is shifted so the result is centred in that box.
    Pixels outside the source image are filled with 255.

    Args:
        img: Image array whose first two dimensions are height and width.
        degree: Rotation angle in degrees, passed to
            ``cv2.getRotationMatrix2D``.

    Returns:
        The rotated image produced by ``cv2.warpAffine``.
    """
    src_h, src_w = img.shape[:2]
    cx, cy = src_w // 2, src_h // 2

    theta = math.radians(degree)
    abs_sin = math.fabs(math.sin(theta))
    abs_cos = math.fabs(math.cos(theta))
    dst_h = int(src_w * abs_sin + src_h * abs_cos)
    dst_w = int(src_h * abs_sin + src_w * abs_cos)

    matrix = cv2.getRotationMatrix2D((cx, cy), degree, 1)
    # Move the rotation centre to the centre of the enlarged canvas.
    matrix[0, 2] += dst_w // 2 - cx
    matrix[1, 2] += dst_h // 2 - cy

    return cv2.warpAffine(img, matrix, (dst_w, dst_h), borderValue=(255, 255, 255))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
@date: 2020/8/21 下午7:29 | ||
@file: __init__.py | ||
@author: zj | ||
@description: | ||
""" |
Oops, something went wrong.