Skip to content

Commit

Permalink
updates
Browse files: browse the repository at this point in the history
  • Loading branch information
wtomin committed Feb 12, 2025
1 parent f9f1bde commit 8300d88
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
13 changes: 9 additions & 4 deletions examples/hunyuanvideo/hyvideo/dataset/text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def parse_data_file(self, data_file_path):
self.dataset = json.load(f)
else:
raise ValueError("Only support json and csv file now!")
assert len(self.dataset) > 0, "No data found in the data file."
sample = self.dataset[0]
assert (
self.caption_column in sample
), f"Expected caption column `{self.caption_column}` in dataset, but got {sample.keys()}"

assert (
self.file_column in sample
), f"Expected file path column {self.file_column} in dataset, but got {sample.keys()}"

def __len__(self):
    """Return ``self.num_captions`` as the length of this dataset."""
    return self.num_captions
Expand All @@ -63,13 +72,9 @@ def read_captions(self, dataset):
def __getitem__(self, idx_text):
idx = self.caption_sample_indices[idx_text]
row = self.dataset[idx]
assert (
self.caption_column in row
), f"Expected caption column {self.caption_column} in dataset, but got {row.keys()}"
captions = row[self.caption_column]
if isinstance(captions, str):
captions = [captions]
assert self.file_column in row, f"Expected file path column {self.file_column} in dataset, but got {row.keys()}"
file_path = row[self.file_column]
# get the caption id
first_text_index = self.caption_sample_indices.index(idx)
Expand Down
3 changes: 1 addition & 2 deletions examples/hunyuanvideo/hyvideo/text_encoder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Optional, Tuple

from hyvideo.constants import PRECISION_TO_TYPE, TEXT_ENCODER_PATH, TOKENIZER_PATH
from hyvideo.utils.helpers import set_model_param_dtype
from transformers import AutoTokenizer, CLIPTokenizer
from transformers.utils import ModelOutput
Expand All @@ -12,8 +13,6 @@
from mindone.transformers.models.llama.modeling_llama import ALL_LAYERNORM_LAYERS
from mindone.utils.amp import auto_mixed_precision

from constants import PRECISION_TO_TYPE, TEXT_ENCODER_PATH, TOKENIZER_PATH


def use_default(value, default):
    """Return ``value`` unless it is ``None``, in which case return ``default``.

    Only ``None`` triggers the fallback: falsy values such as ``0``, ``""``
    or ``False`` are treated as explicitly provided and returned unchanged.

    Args:
        value: The candidate value, possibly ``None``.
        default: The value to fall back to when ``value`` is ``None``.

    Returns:
        ``value`` if it is not ``None``; otherwise ``default``.
    """
    return value if value is not None else default
Expand Down
5 changes: 3 additions & 2 deletions examples/hunyuanvideo/scripts/run_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ def parse_args():
)
parser.add_argument(
"--file-column",
default="path",
help="The column of file path in `data_file_path`. Defaults to `path`.",
default="video",
help="The column of file path in `data_file_path`. Defaults to `video`.",
)
parser.add_argument(
"--caption-column",
Expand Down Expand Up @@ -309,6 +309,7 @@ def main(args):
print(f"rank_id {rank_id}, device_num {device_num}")

# build dataloader for large amount of captions
print_banner("data init")
if args.data_file_path is not None:
assert isinstance(args.data_file_path, str), "Expect data_file_path to be a string!"
assert Path(args.data_file_path).exists(), "data_file_path does not exist!"
Expand Down

0 comments on commit 8300d88

Please sign in to comment.