change "repeat" to "repeat_interleave" and "tile" uniformly #827

Open · wants to merge 5 commits into base: master
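For context (not part of the diff): a numpy-style `Tensor.repeat(n, axis=...)` repeats each element along an axis, while a torch-style `Tensor.repeat(...)` tiles the whole tensor. This PR spells the intent out by using `repeat_interleave` for the former and `tile` for the latter. A minimal NumPy sketch of the two behaviours the renamed calls are assumed to mirror:

```python
import numpy as np

a = np.array([[1, 2],
              [3, 4]])

# Element-wise repetition along an axis: what the old numpy-style
# repeat(2, axis=0) did and what repeat_interleave(2, dim=0) keeps doing.
print(np.repeat(a, 2, axis=0))
# [[1 2]
#  [1 2]
#  [3 4]
#  [3 4]]

# Whole-tensor tiling: what torch-style repeat(2, 1) does and what
# tile((2, 1)) now states explicitly.
print(np.tile(a, (2, 1)))
# [[1 2]
#  [3 4]
#  [1 2]
#  [3 4]]
```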
2 changes: 1 addition & 1 deletion examples/animatediff/ad/modules/attention.py
@@ -284,7 +284,7 @@ def construct(self, q, k, v, mask):
else:
finfo_type = np.float32
max_neg_value = -np.finfo(finfo_type).max
- mask = mask.repeat(self.heads, axis=0)
+ mask = mask.repeat_interleave(self.heads, dim=0)
mask = ops.expand_dims(mask, axis=1)
sim.masked_fill(mask, max_neg_value)

2 changes: 1 addition & 1 deletion examples/animatediff/ad/modules/diffusionmodules/util.py
@@ -189,7 +189,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if dim % 2:
embedding = ops.concat((embedding, ops.ZerosLike()(embedding[:, :1])), axis=-1)
else:
- embedding = ops.reshape(timesteps.repeat(dim), (-1, dim))
+ embedding = ops.reshape(timesteps.repeat_interleave(dim), (-1, dim))
return embedding


2 changes: 1 addition & 1 deletion examples/animatediff/ad/utils/cond_data.py
@@ -24,7 +24,7 @@ def transform_conditional_images(image_paths, H, W, random_crop=True, normalize=
if normalize:

def image_norm(image):
- image = image.mean(dim=0, keepdim=True).repeat(3, 1, 1)
+ image = image.mean(axis=0, keepdims=True).repeat(3, axis=0)
image -= image.min()
image /= image.max()
return image
2 changes: 1 addition & 1 deletion examples/dynamicrafter/lvdm/modules/attention.py
@@ -324,7 +324,7 @@ def construct(self, q, k, v, k_ip, v_ip, out_ip, mask):
else:
finfo_type = np.float32
max_neg_value = -np.finfo(finfo_type).max
- mask = mask.repeat(self.head_num, axis=0)
+ mask = mask.repeat_interleave(self.head_num, dim=0)
mask = ops.expand_dims(mask, axis=1)
sim.masked_fill(mask, max_neg_value)

2 changes: 1 addition & 1 deletion examples/dynamicrafter/lvdm/modules/networks/util.py
@@ -184,7 +184,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if dim % 2:
embedding = ops.concat((embedding, ops.ZerosLike()(embedding[:, :1])), axis=-1)
else:
- embedding = ops.reshape(timesteps.repeat(dim), (-1, dim))
+ embedding = ops.reshape(timesteps.repeat_interleave(dim), (-1, dim))
return embedding


2 changes: 1 addition & 1 deletion examples/instantmesh/models/renderer/synthesizer.py
@@ -195,7 +195,7 @@ def forward_grid(self, planes, grid_size: int, aabb: ms.Tensor = None):
dtype=planes.dtype,
)
.unsqueeze(0)
- .repeat(planes.shape[0], 1, 1)
+ .tile((planes.shape[0], 1, 1))
)
assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
N = planes.shape[0]
2 changes: 1 addition & 1 deletion examples/magvit/tools/inflate_vae2d_to_3d.py
@@ -60,7 +60,7 @@ def inflate(vae2d_ckpt, save_fp):
shape_3d_new = tuple([shape_3d[0] // 2]) + shape_3d[1:]
new_w = ms.ops.zeros(shape_3d_new, dtype=weights.dtype)
new_w[:, :, -1, :, :] = weights
- new_w = new_w.repeat(2, 0)
+ new_w = new_w.tile((2, 1, 1, 1, 1))

new_w = ms.Parameter(new_w, name=key_3d)
new_state_dict[key_3d] = new_w
2 changes: 1 addition & 1 deletion examples/opensora_hpcai/opensora/utils/cond_data.py
@@ -49,7 +49,7 @@ def transform_conditional_images(image_paths, H, W, random_crop=True, normalize=
if normalize:

def image_norm(image):
- image = image.mean(dim=0, keepdim=True).repeat(3, 1, 1)
+ image = image.mean(axis=0, keepdims=True).repeat(3, axis=0)
Collaborator: repeat(3, 1, 1) -> tile(3, 1, 1)?

Collaborator (Author): These are numpy's mean and repeat, so the arguments were adjusted here. For an image of shape (1, H, W), torch's repeat(3, 1, 1) and numpy's repeat(3, axis=0) produce the same result.
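A minimal check of that claim (a sketch assuming numpy and torch are available; not part of this PR):

```python
import numpy as np
import torch

img = np.random.rand(1, 4, 4).astype(np.float32)  # greyscale image of shape (1, H, W)

out_np = img.repeat(3, axis=0)                  # numpy: repeat the single channel 3 times
out_pt = torch.from_numpy(img).repeat(3, 1, 1)  # torch: tile the tensor 3x along dim 0

# Both give shape (3, H, W) with identical values, because the repeated axis has size 1.
assert out_np.shape == tuple(out_pt.shape) == (3, 4, 4)
assert np.array_equal(out_np, out_pt.numpy())
```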

image -= image.min()
image /= image.max()
return image
2 changes: 1 addition & 1 deletion examples/opensora_pku/opensora/eval/cal_fvd.py
@@ -6,7 +6,7 @@
def trans(x):
# if greyscale images add channel
if x.shape[-3] == 1:
- x = x.repeat(1, 1, 3, 1, 1)
+ x = x.tile((1, 1, 3, 1, 1))

# permute BTCHW -> BCTHW
x = x.permute(0, 2, 1, 3, 4)
2 changes: 1 addition & 1 deletion examples/opensora_pku/opensora/eval/cal_lpips.py
@@ -22,7 +22,7 @@
def trans(x):
# if greyscale images add channel
if x.shape[-3] == 1:
- x = x.repeat(3, axis=2)
+ x = x.tile((1, 1, 3, 1, 1))

# value range [0, 1] -> [-1, 1]
x = x * 2 - 1
@@ -21,7 +21,7 @@ def _tile(self, x):
if d < self.n_codes:
n_repeats = (self.n_codes + d - 1) // d
std = 0.01 / np.sqrt(ew)
- x = x.repeat(n_repeats, 1)
+ x = x.tile((n_repeats, 1))
Songyuanwei marked this conversation as resolved.
x = x + ops.randn_like(x) * std
return x

2 changes: 1 addition & 1 deletion examples/stable_diffusion_v2/depth_to_image.py
@@ -157,7 +157,7 @@ def prepare_conditions(depth, txt, num_samples=1, height=512, width=512, vae_sca

# repeat to [bs, 1, h_z, w_z]
depth = np.expand_dims(depth, axis=[0, 1])
- depth = depth.repeat(num_samples, axis=0)
+ depth = depth.tile((num_samples, 1, 1, 1))
assert len(depth.shape) == 4 and depth.shape[1] == 1, f"expect shape [n, 1, h, w], but got {depth.shape}"

depth = Tensor(depth, dtype=mstype.float32)
2 changes: 1 addition & 1 deletion examples/stable_diffusion_v2/inference/sd_lite_infer.py
@@ -128,7 +128,7 @@ def main(args):
)
init_image = Image.open(args.inputs.image_path).convert("RGB")
img = img_processor.preprocess(init_image, height=args.inputs.H, width=args.inputs.W)
inputs["img"] = img.repeat(batch_size, axis=0).asnumpy()
inputs["img"] = img.tile((batch_size, 1, 1, 1)).asnumpy()
init_timestep = min(int(args.sampling_steps * args.inputs.strength), args.sampling_steps)
t_start = max(args.sampling_steps - init_timestep, 0)
inputs["timesteps"] = inputs["timesteps"][t_start * scheduler.order :]
6 changes: 3 additions & 3 deletions examples/stable_diffusion_v2/inpaint.py
@@ -45,10 +45,10 @@ def make_batch_sd(image, mask, txt, num_samples=1):
masked_image = image * (mask < 0.5)

batch = {
"image": image.repeat(num_samples, axis=0),
"image": image.repeat_interleave(num_samples, dim=0),
"txt": num_samples * [txt],
"mask": mask.repeat(num_samples, axis=0),
"masked_image": masked_image.repeat(num_samples, axis=0),
"mask": mask.repeat_interleave(num_samples, dim=0),
"masked_image": masked_image.repeat_interleave(num_samples, dim=0),
}
return batch

@@ -134,7 +134,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if dim % 2:
embedding = ops.concat((embedding, ops.ZerosLike()(embedding[:, :1])), axis=-1)
else:
- embedding = ops.reshape(timesteps.repeat(dim), (-1, dim))
+ embedding = ops.reshape(timesteps.repeat_interleave(dim), (-1, dim))
return embedding


6 changes: 3 additions & 3 deletions examples/story_diffusion/utils/pipeline.py
@@ -406,9 +406,9 @@ def __call__(
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)

bs_embed, seq_len, _ = prompt_embeds.shape
- prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, axis=1)
+ prompt_embeds = prompt_embeds.tile((1, num_images_per_prompt, 1))
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(num_images_per_prompt, axis=1).view(
+ pooled_prompt_embeds = pooled_prompt_embeds.tile((1, num_images_per_prompt, 1)).view(
bs_embed * num_images_per_prompt, -1
)
pooled_prompt_embeds_arr.append(pooled_prompt_embeds)
@@ -472,7 +472,7 @@ def __call__(
text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids = ops.cat([add_time_ids, add_time_ids], axis=0)
- add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, axis=0)
+ add_time_ids = add_time_ids.tile((batch_size * num_images_per_prompt, 1))

# print(latents.shape)
# print(add_time_ids.shape)
2 changes: 1 addition & 1 deletion examples/sv3d/modules/temporal_ae.py
@@ -152,7 +152,7 @@ def construct(self, x: Tensor, timesteps: Tensor, skip_video: bool = False):
x = x.transpose(0, 2, 3, 1).reshape(b, -1, c) # b c h w -> b (h w) c

x_mix = x
- num_frames = ops.arange(timesteps).repeat(b // timesteps)
+ num_frames = ops.arange(timesteps).repeat_interleave(b // timesteps)
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
2 changes: 1 addition & 1 deletion examples/sv3d/sgm/modules/diffusionmodules/guiders.py
@@ -40,7 +40,7 @@ def construct(self, x: Tensor, sigma: Tensor) -> Tensor:
x_c = x_c.reshape(-1, self.num_frames, *x_c.shape[1:])

scale = ops.linspace(self.min_scale, self.max_scale, self.num_frames)[None, :]
- scale = scale.repeat(x_u.shape[0], axis=0) # 1 t -> b t
+ scale = scale.repeat_interleave(x_u.shape[0], dim=0) # 1 t -> b t
scale = append_dims(scale, x_u.ndim)

out = x_u + scale * (x_c - x_u)
6 changes: 3 additions & 3 deletions examples/sv3d/simple_video_sample.py
@@ -94,7 +94,7 @@ def get_batch(
"cond_frames_without_noise": cond_frames_without_noise,
}
batch_uc = {}
batch["cond_aug"] = Tensor(self.cond_aug).repeat(math.prod([1, self.num_frames]))
batch["cond_aug"] = Tensor(self.cond_aug).repeat_interleave(math.prod([1, self.num_frames]))

for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], Tensor):
@@ -128,10 +128,10 @@ def __call__(self, image: Tensor) -> Tensor:

for k in ["crossattn", "concat"]:
uc[k] = self.expand_dims_ops(uc[k], 1)
- uc[k] = uc[k].repeat(self.num_frames, axis=1)
+ uc[k] = uc[k].repeat_interleave(self.num_frames, dim=1)
uc[k] = uc[k].flatten(order="C", start_dim=0, end_dim=1)
c[k] = self.expand_dims_ops(c[k], 1)
- c[k] = c[k].repeat(self.num_frames, axis=1)
+ c[k] = c[k].repeat_interleave(self.num_frames, dim=1)
c[k] = c[k].flatten(order="C", start_dim=0, end_dim=1)

randn = Tensor(self.randn_n)
4 changes: 2 additions & 2 deletions examples/svd/models/diffusion.py
@@ -97,8 +97,8 @@ def construct(
tokens = (cond_frames_without_noise, fps_id, motion_bucket_id, cond_frames, cond_aug)

vector, crossattn, concat = self.conditioner(*tokens)
- crossattn = crossattn.repeat(num_frames, axis=0)
- concat = concat.repeat(num_frames, axis=0)
+ crossattn = crossattn.repeat_interleave(num_frames, dim=0)
+ concat = concat.repeat_interleave(num_frames, dim=0)

c_skip, c_out, c_in, c_noise = self.denoiser(sigmas, noised_input.ndim)
model_output = self.model(
2 changes: 1 addition & 1 deletion examples/svd/modules/diffusionmodules/guiders.py
@@ -29,7 +29,7 @@ def construct(self, x: Tensor, sigma: Tensor, num_frames: int) -> Tensor:
x_c = x_c.reshape(-1, num_frames, *x_c.shape[1:])

scale = ops.linspace(self.min_scale, self.max_scale, num_frames)[None, :]
- scale = scale.repeat(x_u.shape[0], axis=0) # 1 t -> b t
+ scale = scale.repeat_interleave(x_u.shape[0], dim=0) # 1 t -> b t
scale = append_dims(scale, x_u.ndim)

out = x_u + scale * (x_c - x_u)
8 changes: 4 additions & 4 deletions examples/svd/modules/encoders/modules.py
@@ -42,8 +42,8 @@ def construct(
sigmas = self.sigma_sampler(b)
if self.sigma_cond is not None:
sigma_cond = self.sigma_cond(sigmas)
- sigma_cond = sigma_cond.repeat(self.n_copies, axis=0) # b d -> (b t) d
- sigmas = sigmas.repeat(self.n_copies, axis=0) # b -> (b t)
+ sigma_cond = sigma_cond.repeat_interleave(self.n_copies, dim=0) # b d -> (b t) d
+ sigmas = sigmas.repeat_interleave(self.n_copies, dim=0) # b -> (b t)
noise = ops.randn_like(vid)
vid = vid + noise * append_dims(sigmas, vid.ndim)

@@ -61,7 +61,7 @@ def construct(

vid = vid.reshape(-1, self.n_cond_frames, *vid.shape[1:]) # (b t) c h w -> b t c h w
vid = vid.reshape(vid.shape[0], -1, *vid.shape[3:]) # b t c h w -> b (t c) h w
- vid = vid.repeat(self.n_copies, axis=0) # b (t c) h w -> (b s) (t c) h w
+ vid = vid.repeat_interleave(self.n_copies, dim=0) # b (t c) h w -> (b s) (t c) h w

if self.sigma_cond is not None:
return vid, sigma_cond
@@ -84,6 +84,6 @@ def __init__(
def construct(self, vid: Tensor) -> Tensor:
vid = self.open_clip(vid)
vid = vid.reshape(-1, self.n_cond_frames, vid.shape[1]) # (b t) d -> b t d
- vid = vid.repeat(self.n_copies, axis=0) # b t d -> (b s) t d
+ vid = vid.repeat_interleave(self.n_copies, dim=0) # b t d -> (b s) t d

return vid
2 changes: 1 addition & 1 deletion examples/svd/modules/temporal_ae.py
@@ -148,7 +148,7 @@ def construct(self, x: Tensor, timesteps: Tensor, skip_video: bool = False):
x = x.transpose(0, 2, 3, 1).reshape(b, -1, c) # b c h w -> b (h w) c

x_mix = x
- num_frames = ops.arange(timesteps).repeat(b // timesteps)
+ num_frames = ops.arange(timesteps).repeat_interleave(b // timesteps)
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
4 changes: 2 additions & 2 deletions examples/svd/modules/unet3d.py
@@ -292,9 +292,9 @@ def construct(

time_context = context
time_context_first_timestep = time_context[::timesteps]
- time_context = time_context_first_timestep.repeat(h * w, axis=0) # b ... -> (b n) ...
+ time_context = time_context_first_timestep.repeat_interleave(h * w, dim=0) # b ... -> (b n) ...
elif time_context is not None and not self.use_spatial_context:
- time_context = time_context.repeat(h * w, axis=0) # b ... -> (b n) ...
+ time_context = time_context.repeat_interleave(h * w, dim=0) # b ... -> (b n) ...
if time_context.ndim == 2:
time_context = time_context.expand_dims(1) # b c -> b 1 c

@@ -985,7 +985,7 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple
if is_decoder:
batch_size, seq_length = input_shape
seq_ids = ops.arange(seq_length)
- causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ causal_mask = seq_ids[None, None, :].tile((batch_size, seq_length, 1)) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
2 changes: 1 addition & 1 deletion examples/t2v_turbo/lvdm/modules/encoders/ip_resampler.py
@@ -132,7 +132,7 @@ def __init__(
)

def construct(self, x):
- latents = self.latents.repeat(x.shape[0], 1, 1)
+ latents = self.latents.tile((x.shape[0], 1, 1))

x = self.proj_in(x)

4 changes: 2 additions & 2 deletions examples/t2v_turbo/pipeline/t2v_turbo_ms_pipeline.py
@@ -67,7 +67,7 @@ def _encode_prompt(

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.tile((1, num_videos_per_prompt, 1))
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

# Don't need to get uncond prompt embedding because of LCM Guided Distillation
@@ -177,7 +177,7 @@ def __call__(
bs = batch_size * num_videos_per_prompt

# 6. Get Guidance Scale Embedding
- w = ms.Tensor(guidance_scale).repeat(bs)
+ w = ms.Tensor(guidance_scale).tile((bs,))
w_embedding = self.get_w_embedding(w, embedding_dim=256)

# 7. LCM MultiStep Sampling Loop:
4 changes: 2 additions & 2 deletions examples/t2v_turbo/pipeline/t2v_turbo_vc2_pipeline.py
@@ -56,7 +56,7 @@ def _encode_prompt(

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.tile((1, num_videos_per_prompt, 1))
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

# Don't need to get uncond prompt embedding because of LCM Guided Distillation
@@ -171,7 +171,7 @@ def __call__(
bs = batch_size * num_videos_per_prompt

# 6. Get Guidance Scale Embedding
- w = ms.Tensor(guidance_scale).repeat(bs)
+ w = ms.Tensor(guidance_scale).tile((bs,))
w_embedding = self.get_w_embedding(w, embedding_dim=256)

# 7. LCM MultiStep Sampling Loop:
4 changes: 2 additions & 2 deletions examples/t2v_turbo/viclip/viclip_vision.py
@@ -219,11 +219,11 @@ def inflate_weight(weight_2d, time_dim, center=True):
logger.info(f"Init center: {center}")
if center:
weight_3d = ops.zeros(*weight_2d.shape)
- weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
+ weight_3d = weight_3d.unsqueeze(2).tile((1, 1, time_dim, 1, 1))
middle_idx = time_dim // 2
weight_3d[:, :, middle_idx, :, :] = weight_2d
else:
- weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
+ weight_3d = weight_2d.unsqueeze(2).tile((1, 1, time_dim, 1, 1))
weight_3d = weight_3d / time_dim
return weight_3d

2 changes: 1 addition & 1 deletion examples/videocomposer/vc/models/attention.py
@@ -133,7 +133,7 @@ def rearange_in(x):
else:
finfo_type = np.float32
max_neg_value = -np.finfo(finfo_type).max
- mask = mask.repeat(self.heads, axis=0)
+ mask = mask.repeat_interleave(self.heads, dim=0)
mask = ops.expand_dims(mask, axis=1)
sim.masked_fill(mask, max_neg_value)

10 changes: 7 additions & 3 deletions mindone/models/dit.py
@@ -299,7 +299,9 @@ def construct(self, x, mask=None):
if mask is not None:
if mask.ndim == 2:
# mask shape is (batch_size, key_len)
- mask = ops.expand_dims(mask, axis=1).repeat(q_n, axis=1) # (b, k_n) -> (b, 1, k_n) -> (b, q_n, k_n)
+ mask = ops.expand_dims(mask, axis=1).repeat_interleave(
+     q_n, dim=1
+ ) # (b, k_n) -> (b, 1, k_n) -> (b, q_n, k_n)
mask = ops.select(
~mask,
ops.ones((q_b, q_n, k_n), self.dtype) * (-ms.numpy.inf),
@@ -308,7 +310,9 @@
elif mask.ndim == 3:
# mask shape is (batch_size, query_len, key_len), the query_len maybe one
if mask.shape[-2] == 1:
- mask = mask.repeat(q_n, axis=-1) # manually broadcast to key length to avoid FA shape error
+ mask = mask.repeat_interleave(
+     q_n, dim=-1
+ ) # manually broadcast to key length to avoid FA shape error
assert mask.shape[-2] == q_n, "Expect mask shape to be (bs, query_len, key_len), "
f"but the mask query length {mask.shape[-2]} is different from the input query length {q_n}"
assert mask.shape[-1] == k_n, "Expect mask shape to be (bs, query_len, key_len), "
@@ -335,7 +339,7 @@ def construct(self, x, mask=None):
k = self._rearange_in(k, h)
v = self._rearange_in(v, h)
if mask is not None and mask.shape[0] != q.shape[0]:
- mask = mask.repeat(h, axis=0) # (b, q_n, k_n) -> (b*h, q_n, k_n)
+ mask = mask.repeat_interleave(h, dim=0) # (b, q_n, k_n) -> (b*h, q_n, k_n)
out = self.attention(q, k, v, mask)
# (b*h, n, d) -> (b, n, h*d)
out = self._rearange_out(out, h)
2 changes: 1 addition & 1 deletion mindone/models/mmdit.py
@@ -184,7 +184,7 @@ def construct(self, x, c, mask=None, spatial_freq=None):
k = ops.reshape(k, (B * h, T + L, hd))
v = ops.reshape(v, (B * h, T + L, hd))
if mask is not None and mask.shape[0] != q.shape[0]:
- mask = mask.repeat(h, axis=0)
+ mask = mask.repeat_interleave(h, dim=0)
out = self.attention(q, k, v, mask)
out = out.swapaxes(1, 2).view(B, T + L, C) # b, nh, T+L, hd -> b, T+L, nh*hd=C
