feat: add parallel support via mint.distributed for ocr/aes
AndyZhou952 committed Feb 11, 2025
1 parent 2a23eee commit d034800
Showing 5 changed files with 339 additions and 23 deletions.
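At a high level, the commit migrates the aesthetic and OCR scoring pipelines from the legacy mindspore.communication collectives (init, get_group_size, plus the ops.AllGather primitive) to the mindspore.mint.distributed API, whose names follow the torch.distributed convention. Below is a minimal standalone sketch of the calls the diffs adopt; the individual calls appear in the patch, while the surrounding scaffolding is illustrative only:

# Sketch of the mint.distributed surface used by this commit
# (illustrative scaffolding; the individual calls appear in the diffs below).
import mindspore as ms
from mindspore.mint.distributed import init_process_group, get_rank, get_world_size

ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
init_process_group()          # replaces mindspore.communication.init()
rank_id = get_rank()          # same name as before, new module
rank_size = get_world_size()  # replaces get_group_size()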
23 changes: 15 additions & 8 deletions tools/t2v_curation/pipeline/scoring/aesthetic/inference.py
@@ -8,16 +8,16 @@
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.dataset as ds
-from mindspore.communication import get_rank, get_group_size, init
from mindspore import Tensor, load_checkpoint, save_checkpoint, load_param_into_net
+from mindspore.mint.distributed import init_process_group, get_rank, get_world_size, all_gather
from tqdm import tqdm
from transformers import AutoProcessor

from pipeline.datasets.utils import extract_frames, pil_loader, is_video
from pipeline.scoring.utils import merge_scores, NUM_FRAMES_POINTS

__dir__ = os.path.dirname(os.path.abspath(__file__))
-mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../.."))
+mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../../.."))
sys.path.insert(0, mindone_lib_path)

from mindone.transformers import CLIPModel
@@ -116,7 +116,7 @@ def main():
if not args.use_cpu:
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
ms.set_auto_parallel_context(parallel_mode = ms.ParallelMode.DATA_PARALLEL)
-init()
+init_process_group()

model = AestheticScorer()
preprocess = model.processor
@@ -126,7 +126,7 @@
raw_dataset = VideoTextDataset(args.meta_path, transform=preprocess, num_frames=args.num_frames)
if not args.use_cpu:
rank_id = get_rank()
-rank_size = get_group_size()
+rank_size = get_world_size()
dataset = ds.GeneratorDataset(source=raw_dataset, column_names=['index', 'images'], shuffle=False,
num_shards = rank_size, shard_id = rank_id)
else:
@@ -154,14 +154,21 @@
scores_list.extend(scores_np.tolist())

if not args.use_cpu:
-allgather = ops.AllGather()
indices_list = Tensor(indices_list, dtype=ms.int64)
scores_list = Tensor(scores_list, dtype=ms.float32)
-indices_list = allgather(indices_list).asnumpy().tolist()
-scores_list = allgather(scores_list).asnumpy().tolist()

+indices_list_all = [Tensor(np.zeros(indices_list.shape, dtype=np.int64)) for _ in range(rank_size)]
+scores_list_all = [Tensor(np.zeros(scores_list.shape, dtype=np.float32)) for _ in range(rank_size)]

+all_gather(indices_list_all, indices_list)
+all_gather(scores_list_all, scores_list)

+concat = ops.Concat(axis = 0)
+indices_list_all = concat(indices_list_all).asnumpy().tolist()
+scores_list_all = concat(scores_list_all).asnumpy().tolist()

if args.use_cpu or (not args.use_cpu and rank_id == 0):
-meta_local = merge_scores([(indices_list, scores_list)], raw_dataset.meta, column="aes")
+meta_local = merge_scores([(indices_list_all, scores_list_all)], raw_dataset.meta, column="aes")
meta_local.to_csv(out_path, index = False)
print(meta_local)
print(f"New meta with aesthetic scores saved to '{out_path}'.")
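The gather sequence above is the heart of the change: instead of calling the ops.AllGather primitive, mint.distributed.all_gather fills a caller-allocated list with one tensor per rank, and a Concat along axis 0 stitches the shards back together. A minimal sketch of that pattern as a reusable helper, assuming equal-shaped shards on every rank; the name gather_concat is hypothetical and not part of the patch:

import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.mint.distributed import all_gather, get_world_size

def gather_concat(values, np_dtype, ms_dtype):
    # Convert this rank's Python list into a tensor to be gathered.
    local = Tensor(np.array(values), dtype=ms_dtype)
    # all_gather writes into a pre-allocated output list, one tensor per rank.
    buckets = [Tensor(np.zeros(local.shape, dtype=np_dtype)) for _ in range(get_world_size())]
    all_gather(buckets, local)
    # Concatenate the per-rank tensors back into one flat Python list.
    return ops.Concat(axis=0)(buckets).asnumpy().tolist()

With such a helper, the aesthetic branch above and the four numeric OCR columns below would each reduce to a single call.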
1 change: 0 additions & 1 deletion tools/t2v_curation/pipeline/scoring/ocr/config.py
@@ -96,7 +96,6 @@ def create_parser():
type=str,
help="directory containing the recognition model checkpoint best.ckpt, or path to a specific checkpoint file.",
) # determine the network weights
-# parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
parser.add_argument(
"--rec_image_shape",
type=str,
78 changes: 66 additions & 12 deletions tools/t2v_curation/pipeline/scoring/ocr/inference.py
@@ -3,11 +3,15 @@
import pandas as pd
import numpy as np
import mindspore as ms
+import mindspore.ops as ops
import mindspore.dataset as ds
-from mindspore.communication import get_rank, get_group_size, init
+from mindspore import Tensor
+from mindspore.mint.distributed import init_process_group, get_rank, get_world_size
+from mindspore.mint.distributed import all_gather, all_gather_object
from tqdm import tqdm

from pipeline.datasets.utils import extract_frames, pil_loader, is_video
+from pipeline.scoring.utils import merge_scores
from config import parse_args
from text_system import TextSystem

@@ -58,22 +62,22 @@ def main():

ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
-init()
+init_process_group()

# initialize the TextSystem
text_system = TextSystem(args)

raw_dataset = VideoTextDataset(args.meta_path)
rank_id = get_rank()
-rank_size = get_group_size()
+rank_size = get_world_size()
dataset = ds.GeneratorDataset(
source=raw_dataset,
column_names=['index', 'images', 'height', 'width'],
shuffle=False,
num_shards=rank_size,
shard_id=rank_id
)
-# TODO: Batch size > 1 only supports images with the same shapes
+# Do NOT set batch size > 1 unless all the images share the same shape; otherwise an error is raised
dataset = dataset.batch(args.bs, drop_remainder=False)
iterator = dataset.create_dict_iterator(num_epochs=1)

@@ -162,14 +166,64 @@ def main():
total_text_percentage_list.extend(batch_total_text_percentage)

-meta_local = raw_dataset.meta.copy()
-meta_local['ocr'] = ocr_results_list

-if compute_num_boxes:
-meta_local['num_boxes'] = num_boxes_list
-if compute_max_single_text_box_area_percentage:
-meta_local['max_single_percentage'] = max_single_percentage_list
-if compute_total_text_area_percentage:
-meta_local['total_text_percentage'] = total_text_percentage_list
+if rank_size > 1:
+# indices
+indices_list = Tensor(np.array(indices_list), dtype=ms.int64)
+indices_list_all = [Tensor(np.zeros(indices_list.shape, dtype=np.int64)) for _ in range(rank_size)]
+all_gather(indices_list_all, indices_list)
+indices_list_all = ops.Concat(axis=0)(indices_list_all).asnumpy().tolist()

+# num_boxes
+if compute_num_boxes:
+num_boxes_list = Tensor(np.array(num_boxes_list), dtype=ms.int32)
+num_boxes_list_all = [Tensor(np.zeros(num_boxes_list.shape, dtype=np.int32)) for _ in range(rank_size)]
+all_gather(num_boxes_list_all, num_boxes_list)
+num_boxes_list_all = ops.Concat(axis=0)(num_boxes_list_all).asnumpy().tolist()
+else:
+num_boxes_list_all = None

+# max_single_percentage_list
+if compute_max_single_text_box_area_percentage:
+max_single_percentage_list = Tensor(np.array(max_single_percentage_list), dtype=ms.float32)
+max_single_percentage_list_all = [Tensor(np.zeros(max_single_percentage_list.shape, dtype=np.float32)) for _ in range(rank_size)]
+all_gather(max_single_percentage_list_all, max_single_percentage_list)
+max_single_percentage_list_all = ops.Concat(axis=0)(max_single_percentage_list_all).asnumpy().tolist()
+else:
+max_single_percentage_list_all = None

+# total_text_percentage_list
+if compute_total_text_area_percentage:
+total_text_percentage_list = Tensor(np.array(total_text_percentage_list), dtype=ms.float32)
+total_text_percentage_list_all = [Tensor(np.zeros(total_text_percentage_list.shape, dtype=np.float32)) for _ in range(rank_size)]
+all_gather(total_text_percentage_list_all, total_text_percentage_list)
+total_text_percentage_list_all = ops.Concat(axis=0)(total_text_percentage_list_all).asnumpy().tolist()
+else:
+total_text_percentage_list_all = None

+# ocr_results_list
+ocr_results_list_all = [None] * rank_size
+all_gather_object(ocr_results_list_all, ocr_results_list)
+# Flatten the list-of-lists from each process into a single list
+ocr_results_list_all = sum(ocr_results_list_all, [])

+meta_local = merge_scores([(indices_list_all, ocr_results_list_all)], raw_dataset.meta, column="ocr")
+if compute_num_boxes:
+meta_local = merge_scores([(indices_list_all, num_boxes_list_all)], meta_local, column="num_boxes")
+if compute_max_single_text_box_area_percentage:
+meta_local = merge_scores([(indices_list_all, max_single_percentage_list_all)], meta_local,
+column="max_single_percentage")
+if compute_total_text_area_percentage:
+meta_local = merge_scores([(indices_list_all, total_text_percentage_list_all)], meta_local,
+column="total_text_percentage")
+else: # store directly without gathering
+meta_local = raw_dataset.meta.copy()
+meta_local['ocr'] = ocr_results_list
+if compute_num_boxes:
+meta_local['num_boxes'] = num_boxes_list
+if compute_max_single_text_box_area_percentage:
+meta_local['max_single_percentage'] = max_single_percentage_list
+if compute_total_text_area_percentage:
+meta_local['total_text_percentage'] = total_text_percentage_list

meta_local.to_csv(out_path, index=False)
print(meta_local)
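The OCR strings take a different path from the numeric columns: they are variable-length Python objects, so the patch gathers them with all_gather_object, which handles arbitrary picklable objects and needs no shape-matched buffers. A minimal sketch of that path; the helper name gather_objects is hypothetical:

from mindspore.mint.distributed import all_gather_object, get_world_size

def gather_objects(local_list):
    # One slot per rank; all_gather_object fills each slot with that rank's object.
    buckets = [None] * get_world_size()
    all_gather_object(buckets, local_list)
    # Flatten the rank-major list-of-lists into a single list.
    return sum(buckets, [])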
(Diffs for the remaining two changed files are not shown here.)
