Skip to content

Commit

Permalink
feat(tsn): 实现demo版本
Browse files Browse the repository at this point in the history
1. 数据定义
2. 格式转换
3. 批量处理
4. 模型定义
5. 损失函数/优化器/学习率调度器设置
6. 分阶段保存和测试模型
  • Loading branch information
zjykzj committed Aug 28, 2020
1 parent 30f6fdb commit 68ea159
Show file tree
Hide file tree
Showing 23 changed files with 764 additions and 0 deletions.
Binary file added imgs/TSN.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 7 additions & 0 deletions imgs/TSN.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 34 additions & 0 deletions test/hmdb51.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/28 下午5:06
@file: hmdb51.py
@author: zj
@description:
"""

import torchvision.transforms as transforms

from tsn.data.hmdb51 import HMDB51


def get_transform():
    """Build the preprocessing pipeline applied to every raw frame.

    Returns:
        A ``transforms.Compose`` that converts the input to a PIL image,
        resizes it to 224x224 and converts the result to a tensor.
    """
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


if __name__ == '__main__':
    # Manual smoke test: load split 1 of the training list and inspect one sample.
    transform = get_transform()
    hmdb51_root = '/home/zj/zhonglian/mmaction2/data/hmdb51'

    data_set = HMDB51(
        '/home/zj/zhonglian/mmaction2/data/hmdb51/rawframes',
        hmdb51_root,
        split=1,
        num_seg=8,
        train=True,
        transform=transform,
    )
    print(data_set)
    print(len(data_set))

    # Plain indexing goes through HMDB51.__getitem__.
    image, target = data_set[100]
    print(image.shape)
    print(target)
34 changes: 34 additions & 0 deletions tools/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/22 下午4:20
@file: predict.py
@author: zj
@description:
"""

import cv2
import torch
from tsn.data.build import build_test_transform
from tsn.model.build import build_model
from tsn.util.checkpoint import CheckPointer

if __name__ == '__main__':
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Build the network and let the checkpointer load whatever it finds
    # under ``output_dir``.
    # NOTE(review): num_classes=360, the grayscale read and 'imgs/RotNet.png'
    # look like leftovers from a rotation-prediction project; train.py builds
    # a 51-class model and the test transform normalizes three channels --
    # confirm before relying on this script.
    model = build_model(num_classes=360).to(device)
    output_dir = './outputs'
    checkpointer = CheckPointer(model, save_dir=output_dir)
    checkpointer.load()
    # Inference only: switch dropout / batch-norm layers to evaluation
    # behaviour (assumes build_model returns a torch.nn.Module).
    model.eval()

    transform = build_test_transform()
    img_path = 'imgs/RotNet.png'
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f'cannot read image: {img_path}')
    print(img.shape)
    # Add the batch dimension expected by the model.
    res_img = transform(img).unsqueeze(0)
    print(res_img.shape)

    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        outputs = model(res_img.to(device))
    _, preds = torch.max(outputs, 1)
    print(preds)
37 changes: 37 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/21 下午7:52
@file: train.py
@author: zj
@description:
"""

import os
import torch

from tsn.data.build import build_dataloader
from tsn.model.build import build_model, build_criterion
from tsn.optim.build import build_optimizer, build_lr_scheduler
from tsn.engine.build import train_model
from tsn.util.checkpoint import CheckPointer

if __name__ == '__main__':
    # Training entry point: wire data, model, optimizer and checkpointing
    # together, then hand everything to ``train_model``.
    epoches = 10
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    data_loaders, data_sizes = build_dataloader()

    criterion = build_criterion()
    model = build_model(num_classes=51).to(device)
    optimizer = build_optimizer(model)
    lr_scheduler = build_lr_scheduler(optimizer)

    output_dir = './outputs'
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # ``os.path.exists`` + ``os.mkdir`` and also handles nested paths.
    os.makedirs(output_dir, exist_ok=True)
    checkpointer = CheckPointer(model, optimizer=optimizer, scheduler=lr_scheduler, save_dir=output_dir,
                                save_to_disk=True, logger=None)

    train_model('MobileNet_v2', model, criterion, optimizer, lr_scheduler, data_loaders, data_sizes, checkpointer,
                epoches=epoches, device=device)
8 changes: 8 additions & 0 deletions tsn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/5 下午4:23
@file: __init__.py
@author: zj
@description:
"""
8 changes: 8 additions & 0 deletions tsn/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/21 下午7:20
@file: __init__.py
@author: zj
@description:
"""
62 changes: 62 additions & 0 deletions tsn/data/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/21 下午7:20
@file: build.py
@author: zj
@description:
"""

from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from .hmdb51 import HMDB51


def build_train_transform():
    """Assemble the augmentation pipeline used for training frames.

    Returns:
        A ``(transform, None)`` pair. The first element is a
        ``transforms.Compose``; the second element is always ``None``.
    """
    augmentations = [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # RandomErasing is placed after ToTensor/Normalize, i.e. it acts on the tensor.
        transforms.RandomErasing(),
    ]
    return transforms.Compose(augmentations), None


def build_test_transform():
    """Assemble the deterministic preprocessing pipeline used at test time.

    Returns:
        A ``transforms.Compose`` that converts to PIL, resizes to 224x224,
        converts to a tensor and normalizes each of the three channels with
        mean 0.5 and std 0.5.
    """
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(steps)


def build_dataset(data_dir='/home/zj/zhonglian/mmaction2/data/hmdb51/rawframes',
                  annotation_dir='/home/zj/zhonglian/mmaction2/data/hmdb51'):
    """Create the HMDB51 train and test datasets for split 1.

    Args:
        data_dir: Root directory holding the extracted raw frames.
        annotation_dir: Directory holding the split annotation files.

    Returns:
        A pair ``(data_sets, data_sizes)``; both are dicts keyed by
        ``'train'`` and ``'test'``, holding the dataset objects and their
        lengths respectively.
    """
    train_transform, _ = build_train_transform()
    test_transform = build_test_transform()

    train_dataset = HMDB51(data_dir, annotation_dir, num_seg=3, split=1,
                           train=True, transform=train_transform)
    # train=False makes HMDB51 read the validation annotation list; using
    # train=True here would evaluate on the training videos.
    test_dataset = HMDB51(data_dir, annotation_dir, num_seg=3, split=1,
                          train=False, transform=test_transform)

    data_sets = {'train': train_dataset, 'test': test_dataset}
    data_sizes = {name: len(data_set) for name, data_set in data_sets.items()}
    return data_sets, data_sizes


def build_dataloader(batch_size=32, num_workers=8):
    """Wrap the train/test datasets in ``DataLoader`` objects.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of loader worker processes for both loaders.

    Returns:
        A pair ``(data_loaders, data_sizes)``; both are dicts keyed by
        ``'train'`` and ``'test'``.
    """
    data_sets, data_sizes = build_dataset()

    train_dataloader = DataLoader(data_sets['train'], batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers)
    # Shuffling serves no purpose for evaluation; keep the test order fixed.
    test_dataloader = DataLoader(data_sets['test'], batch_size=batch_size, shuffle=False,
                                 num_workers=num_workers)

    return {'train': train_dataloader, 'test': test_dataloader}, data_sizes
73 changes: 73 additions & 0 deletions tsn/data/hmdb51.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/28 下午4:37
@file: hmdb51.py
@author: zj
@description:
"""

import cv2
import random
import os
import numpy as np

import torch
from torch.utils.data import Dataset


class HMDB51(Dataset):
    """HMDB51 raw-frame dataset that samples ``num_seg`` frames per video.

    Every non-empty line of the annotation file has the form
    ``<frame_dir> <num_frames> <label>``.
    """

    def __init__(self, data_dir, annotation_dir, num_seg=3, split=1, train=True, transform=None,
                 start_index=1):
        """Parse the annotation list of the requested split.

        Args:
            data_dir: Root directory that contains one frame folder per video.
            annotation_dir: Directory that contains the annotation files.
            num_seg: Number of frames returned for each video.
            split: Split number used in the annotation file name.
            train: Read the ``train`` list when true, the ``val`` list otherwise.
            transform: Optional callable applied to every frame.
            start_index: Number of the first frame file in a video folder.
                Frame files are named ``img_00001.jpg``, ``img_00002.jpg``, ...
                in the mmaction2 raw-frame layout, hence the default of 1.

        Raises:
            ValueError: If the annotation file does not exist.
        """
        subset = 'train' if train else 'val'
        annotation_path = os.path.join(annotation_dir, f'hmdb51_{subset}_split_{split}_rawframes.txt')

        if not os.path.isfile(annotation_path):
            raise ValueError(f'{annotation_path}不是文件路径')

        self.data_dir = data_dir
        self.transform = transform
        self.num_seg = num_seg
        self.start_index = start_index

        video_list = []
        img_num_list = []
        cate_list = []
        with open(annotation_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # Split from the right so only the last two fields are
                # treated as numbers, whatever whitespace separates them.
                dir_name, img_num, cate = line.rsplit(maxsplit=2)

                video_list.append(dir_name)
                img_num_list.append(int(img_num))
                cate_list.append(int(cate))
        self.video_list = video_list
        self.img_num_list = img_num_list
        self.cate_list = cate_list

    def _sample_frame_ids(self, index):
        """Return ``num_seg`` frame numbers of video ``index`` in ascending order."""
        img_num = self.img_num_list[index]
        if img_num <= 0:
            raise ValueError(f'video {self.video_list[index]} has no frames')

        frame_ids = range(self.start_index, self.start_index + img_num)
        if img_num >= self.num_seg:
            picked = random.sample(frame_ids, self.num_seg)
        else:
            # Fewer frames than segments: sample with replacement instead of
            # letting random.sample raise ValueError.
            picked = random.choices(frame_ids, k=self.num_seg)
        # Keep the sampled frames in temporal order.
        return sorted(picked)

    def __getitem__(self, index: int):
        """Randomly pick ``num_seg`` frames from the selected video folder.

        Args:
            index: Position of the video in the annotation list.

        Returns:
            ``(image, target)`` where ``image`` stacks the sampled frames
            along a new first dimension -- (T, C, H, W) with T = ``num_seg``
            when the transform yields (C, H, W) tensors -- and ``target`` is
            the integer class label. Without a transform the raw arrays
            returned by ``cv2.imread`` are stacked unchanged.

        Raises:
            IndexError: If ``index`` is out of range.
            FileNotFoundError: If a sampled frame cannot be read.
        """
        num_videos = len(self.video_list)
        if not -num_videos <= index < num_videos:
            raise IndexError(f'index {index} is out of range for {num_videos} videos')
        target = self.cate_list[index]
        video_path = os.path.join(self.data_dir, self.video_list[index])

        image_list = []
        for num in self._sample_frame_ids(index):
            image_path = os.path.join(video_path, 'img_{:0>5d}.jpg'.format(num))
            # NOTE(review): cv2.imread yields BGR channel order and the frame
            # is passed on unchanged -- confirm the transform expects that.
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread signals a missing/unreadable file by returning None.
                raise FileNotFoundError(f'cannot read frame: {image_path}')

            if self.transform is not None:
                img = self.transform(img)
            else:
                # torch.stack only accepts tensors.
                img = torch.from_numpy(img)
            image_list.append(img)
        image = torch.stack(image_list)

        return image, target

    def __len__(self) -> int:
        """Return the number of videos listed in the annotation file."""
        return len(self.video_list)
51 changes: 51 additions & 0 deletions tsn/data/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/22 上午11:09
@file: mnist.py
@author: zj
@description:
"""

import numpy as np
from PIL import Image
from torchvision.datasets import FashionMNIST


class FMNIST(FashionMNIST):
    """FashionMNIST variant whose ``target_transform`` rewrites image and label together."""

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Forward all arguments unchanged to ``FashionMNIST``."""
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)

    def __getitem__(self, index):
        """Rotate the image first, then apply the image preprocessing.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target). When ``target_transform`` is set, both
            values are whatever it returns for the raw image array;
            otherwise ``target`` is the stored class label.
        """
        pixels = self.data[index].numpy()
        target = int(self.targets[index])

        if self.target_transform is not None:
            # Here target_transform receives the image array and returns
            # an (image, target) pair.
            pixels, target = self.target_transform(pixels)

        # Return a PIL image so this stays consistent with other datasets.
        img = Image.fromarray(pixels, mode='L')

        if self.transform is not None:
            img = self.transform(img)

        return img, target


if __name__ == '__main__':
    # Manual check: fetch one sample and print its shape and label.
    dataset = FMNIST('../../data/', train=True)
    sample_img, sample_target = dataset[10]
    print(np.array(sample_img).shape)
    print(sample_target)
42 changes: 42 additions & 0 deletions tsn/data/rotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/22 上午11:14
@file: rotate.py
@author: zj
@description:
"""

import math
import random
import cv2
import numpy as np


class Rotate:
    """Callable that rotates an image by a random whole-degree angle."""

    def __call__(self, img: np.ndarray):
        """Rotate ``img`` by a random integer angle between 0 and 359 inclusive.

        Args:
            img: Image as a NumPy array.

        Returns:
            ``(rotated_image, angle)`` with ``angle`` in degrees.
        """
        assert isinstance(img, np.ndarray)

        degree = random.randint(0, 359)
        return rotate(img, degree), degree


def rotate(img, degree):
    """Rotate ``img`` by ``degree`` degrees on a canvas resized to fit the result.

    Args:
        img: Image array; only its first two dimensions (height, width) are read.
        degree: Rotation angle in degrees.

    Returns:
        The rotated image; uncovered border pixels are filled with
        ``(255, 255, 255)``.
    """
    height, width = img.shape[:2]
    cx, cy = width // 2, height // 2

    # Size of the axis-aligned bounding box of the rotated rectangle.
    theta = math.radians(degree)
    abs_sin = math.fabs(math.sin(theta))
    abs_cos = math.fabs(math.cos(theta))
    out_h = int(width * abs_sin + height * abs_cos)
    out_w = int(height * abs_sin + width * abs_cos)

    affine = cv2.getRotationMatrix2D((cx, cy), degree, 1)
    # Shift the translation so the source centre maps to the new canvas centre.
    affine[0, 2] += out_w // 2 - cx
    affine[1, 2] += out_h // 2 - cy

    return cv2.warpAffine(img, affine, (out_w, out_h), borderValue=(255, 255, 255))
8 changes: 8 additions & 0 deletions tsn/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/8/21 下午7:29
@file: __init__.py
@author: zj
@description:
"""
Loading

0 comments on commit 68ea159

Please sign in to comment.