Refactor Monarch variable names to boost readability #40

Open · wants to merge 2 commits into base: main
104 changes: 67 additions & 37 deletions bert/src/mm/blockdiag_butterfly_multiply.py
@@ -46,68 +46,98 @@ def blockdiag_butterfly_multiply_reference(x, w1_bfly, w2_bfly, version=2):
out2 = rearrange(out2, 'b (l s) -> b (s l)', l=l)
return out2


class BlockdiagButterflyMultiply(torch.autograd.Function):

"""This is a faster implementation, with careful memory copies for the fastest
bmm performance.
The backward pass is also written manually with careful memory copies.
Arguments:
x: (batch, n)
w1_bfly: (k, q, p), where k = n / p
w2_bfly: (l, s, r), where l = k * q / r = n * q / (p * r)
w1_bfly: (nblocks1, blk1_out, blk1_in), where nblocks1 * blk1_in = n
w2_bfly: (nblocks2, blk2_out, blk2_in), where nblocks2 * blk2_in = nblocks1 * blk1_out
Outputs:
out: (batch, m), where m = l * s = n * s * q / (p * r)
"""

@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16)
def forward(ctx, x, w1_bfly, w2_bfly):
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.bfloat16)
def forward(ctx, x, w1_bfly, w2_bfly, debug_out1=False):
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = np.prod(batch_shape)
k, q, p = w1_bfly.shape
l, s, r = w2_bfly.shape
assert k * p == n
assert l * r == k * q
x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1)
out1 = torch.empty(batch_dim, k, q, device=x.device, dtype=x.dtype).transpose(0, 1)
out1 = torch.bmm(x_reshaped, w1_bfly.transpose(-1, -2), out=out1)
out1 = out1.transpose(0, 1).reshape(batch_dim, r, l).transpose(-1, -2).contiguous().transpose(0, 1)
out2 = torch.empty(batch_dim, l, s, device=x.device, dtype=x.dtype).transpose(0, 1)
out2 = torch.bmm(out1, w2_bfly.transpose(-1, -2), out=out2)
out2 = out2.permute(1, 2, 0).reshape(*batch_shape, s * l)
ctx.save_for_backward(x, w1_bfly, w2_bfly, out1)
seq_dim = np.prod(batch_shape)

w1_bfly = w1_bfly.to(x.dtype)
w2_bfly = w2_bfly.to(x.dtype)

# Typically blk1_out = blk2_in and nblocks1 = nblocks2
# e.g. (4, 4, 1024)
nblocks1, blk1_out, blk1_in = w1_bfly.shape
nblocks2, blk2_out, blk2_in = w2_bfly.shape
assert nblocks1 * blk1_in == n
assert nblocks2 * blk2_in == nblocks1 * blk1_out

# Typical shape for Llama 7B on Math reasoning: (4, 666, 1024)
x_reshaped = x.reshape(seq_dim, nblocks1, blk1_in).transpose(0, 1)
out1 = torch.empty(nblocks1, seq_dim, blk1_out, device=x.device, dtype=x.dtype)

# (nblocks1, seq_dim, blk1_in) @ (nblocks1, blk1_in, blk1_out)
out1 = torch.bmm(x_reshaped, w1_bfly.transpose(-1, -2), out=out1) # -> (nblocks1, seq_dim, blk1_out)
del x_reshaped

# Feature shuffling
out1 = (
out1.transpose(0, 1).reshape(seq_dim, blk2_in, nblocks2).permute(2, 0, 1)
) # (seq_dim, nblocks1, blk1_out) -> (seq_dim, blk2_in, nblocks2) -> (nblocks2, seq_dim, blk2_in)

out2 = torch.empty(nblocks2, seq_dim, blk2_out, device=x.device, dtype=x.dtype)
out2 = torch.bmm(
out1, w2_bfly.transpose(-1, -2), out=out2
) # (nblocks2, seq_dim, blk2_in) @ (nblocks2, blk2_in, blk2_out) -> (nblocks2, seq_dim, blk2_out)

out2 = out2.permute(1, 2, 0).reshape(
*batch_shape, blk2_out * nblocks2
) # (nblocks2, seq_dim, blk2_out) -> (seq_dim, blk2_out * nblocks2)

ctx.save_for_backward(x, w1_bfly, w2_bfly, out1, None, None)
if debug_out1:
return out2, out1
return out2
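
Reviewer note: the transpose/reshape/permute "feature shuffling" in forward() replaces the einops rearranges used by the reference path. A minimal standalone sketch, with hypothetical small dimensions (einops is already a dependency of this module), showing the two formulations produce identical tensors:

```python
import torch
from einops import rearrange

# Hypothetical sizes with nblocks1 * blk1_out == nblocks2 * blk2_in (4 * 6 == 3 * 8).
nblocks1, blk1_out, nblocks2, blk2_in, seq_dim = 4, 6, 3, 8, 5
out1 = torch.randn(nblocks1, seq_dim, blk1_out)

# Fast path: plain transpose/reshape/permute, as in forward() above.
shuffled = out1.transpose(0, 1).reshape(seq_dim, blk2_in, nblocks2).permute(2, 0, 1)

# Reference-style path: the same shuffle expressed with einops rearranges.
reference = rearrange(rearrange(out1, 'k b q -> b (k q)'), 'b (r l) -> l b r', l=nblocks2)

assert torch.equal(shuffled, reference)  # bit-identical: it is only an index permutation
```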

@staticmethod
@torch.cuda.amp.custom_bwd
@torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, dout):
x, w1_bfly, w2_bfly, out1 = ctx.saved_tensors
x, w1_bfly, w2_bfly, out1, *_ = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = np.prod(batch_shape)
k, q, p = w1_bfly.shape
l, s, r = w2_bfly.shape
# assert k * p == n
# assert l * r == k * q
seq_dim = np.prod(batch_shape)
nblocks1, blk1_out, blk1_in = w1_bfly.shape
nblocks2, blk2_out, blk2_in = w2_bfly.shape

dx, dw1_bfly, dw2_bfly = None, None, None
# dout_reshaped = dout.reshape(batch_dim, sqrtn, sqrtn).permute(2, 1, 0).contiguous()
dout_reshaped = dout.reshape(batch_dim, s, l).transpose(-1, -2).contiguous()
dout_reshaped = dout_reshaped.transpose(0, 1)

dout_reshaped = dout.reshape(seq_dim, blk2_out, nblocks2).transpose(-1, -2)
dout_reshaped = dout_reshaped.transpose(0, 1).contiguous() # (nblocks2, seq_dim, blk2_out)
if ctx.needs_input_grad[2]:
# dw2_bfly = torch.empty(l, s, r, device=w2_bfly.device, dtype=w2_bfly.dtype)
# dw2_bfly = torch.empty(nblocks2, blk2_out, blk2_in, device=w2_bfly.device, dtype=w2_bfly.dtype)
# dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1, out=dw2_bfly)

# (nblocks2, blk2_out, seq_dim) @ (nblocks2, seq_dim, blk2_in) -> (nblocks2, blk2_out, blk2_in)
dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1.conj())
if ctx.needs_input_grad[1] or ctx.needs_input_grad[0]:
dout1 = torch.empty(batch_dim, l, r, device=x.device, dtype=x.dtype).transpose(0, 1)
dout1 = torch.bmm(dout_reshaped, w2_bfly.conj(), out=dout1)
dout1 = dout1.transpose(0, 1).transpose(-1, -2).contiguous().reshape(batch_dim, k, q).transpose(0, 1)
# dout1 = dout1.permute(1, 2, 0).contiguous().transpose(0, 1)
dout1 = torch.empty(nblocks2, seq_dim, blk2_in, device=x.device, dtype=x.dtype)
dout1 = torch.bmm(dout_reshaped, w2_bfly.conj(), out=dout1) # -> (nblocks2, seq_dim, blk2_in)
del dout_reshaped
# dout1 = dout1.transpose(0, 1).transpose(-1, -2).contiguous().reshape(seq_dim, nblocks1, blk1_out).transpose(0, 1)
# NOTE: we do not need an intermediate .contiguous() here, which should save memory & time
dout1 = (
dout1.permute(1, 2, 0).reshape(seq_dim, nblocks1, blk1_out).transpose(0, 1)
) # -> (nblocks1, seq_dim, blk1_out)
if ctx.needs_input_grad[0]:
dx = torch.empty(batch_dim, k, p, device=x.device, dtype=x.dtype)
dx = torch.empty(seq_dim, nblocks1, blk1_in, device=x.device, dtype=x.dtype)
# (nblocks1, seq_dim, blk1_out) @ (nblocks1, blk1_out, blk1_in) -> (nblocks1, seq_dim, blk1_in)
dx = torch.bmm(dout1, w1_bfly.conj(), out=dx.transpose(0, 1)).transpose(0, 1).reshape(*batch_shape, n)
if ctx.needs_input_grad[1]:
x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1)
x_reshaped = x.reshape(seq_dim, nblocks1, blk1_in).transpose(0, 1)
# (nblocks1, blk1_out, seq_dim) @ (nblocks1, seq_dim, blk1_in) -> (nblocks1, blk1_out, blk1_in)
dw1_bfly = torch.bmm(dout1.transpose(-1, -2), x_reshaped.conj())
return dx, dw1_bfly, dw2_bfly
return dx, dw1_bfly, dw2_bfly, None, None


blockdiag_butterfly_multiply = BlockdiagButterflyMultiply.apply
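
Reviewer note: a minimal sanity check one could run against this branch, using hypothetical small dimensions. The import path assumes `bert/` is on `PYTHONPATH`, and the reference is called on flattened 2-D input since it may expect `(batch, n)`:

```python
import torch

# Assumed import path; adjust to your checkout layout.
from src.mm.blockdiag_butterfly_multiply import (
    blockdiag_butterfly_multiply,
    blockdiag_butterfly_multiply_reference,
)

torch.manual_seed(0)
# Hypothetical Monarch factors: n = nblocks1 * blk1_in = 32, output width = nblocks2 * blk2_out = 20.
nblocks1, blk1_out, blk1_in = 4, 6, 8
nblocks2, blk2_out, blk2_in = 4, 5, 6        # nblocks2 * blk2_in == nblocks1 * blk1_out
x = torch.randn(2, 3, nblocks1 * blk1_in, requires_grad=True)   # (batch, seq, n)
w1_bfly = torch.randn(nblocks1, blk1_out, blk1_in, requires_grad=True)
w2_bfly = torch.randn(nblocks2, blk2_out, blk2_in, requires_grad=True)

# Run outside autocast, so custom_fwd does not cast anything to bfloat16.
out = blockdiag_butterfly_multiply(x, w1_bfly, w2_bfly)
assert out.shape == (2, 3, nblocks2 * blk2_out)

# The fast path should match the slow einsum reference.
ref = blockdiag_butterfly_multiply_reference(
    x.reshape(-1, x.shape[-1]), w1_bfly, w2_bfly, version=2
)
assert torch.allclose(out.reshape(-1, out.shape[-1]), ref, atol=1e-4)

# The hand-written backward should populate gradients for all three tensor inputs.
out.sum().backward()
assert x.grad is not None and w1_bfly.grad is not None and w2_bfly.grad is not None
```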
42 changes: 23 additions & 19 deletions bert/src/mm/blockdiag_multiply.py
@@ -35,48 +35,52 @@ def blockdiag_multiply_reference(x, weight):


class BlockdiagMultiply(torch.autograd.Function):

"""This is a faster implementation, with careful memory copies for the fastest
bmm performance.
The backward pass is also written manually with careful memory copies.
Arguments:
x: (..., n)
weight: (nblocks, q, n / nblocks)
weight: (nblocks, blk_out, blk_in), where nblocks * blk_in = n
Outputs:
out: (..., nblocks * q)
out: (..., nblocks * blk_out)
"""

@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16)
@torch.cuda.amp.custom_fwd()
def forward(ctx, x, weight):
ctx.save_for_backward(x, weight)
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = np.prod(batch_shape)
nblocks, q, p = weight.shape
assert nblocks * p == n
x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1)
out = torch.empty(batch_dim, nblocks, q, device=x.device, dtype=x.dtype).transpose(0, 1)
out = torch.bmm(x_reshaped, weight.transpose(-1, -2), out=out).transpose(0, 1)
return out.reshape(*batch_shape, nblocks * q)
seq_dim = np.prod(batch_shape)
nblocks, blk_out, blk_in = weight.shape
assert nblocks * blk_in == n
x_reshaped = x.view(seq_dim, nblocks, blk_in).transpose(0, 1) # (nblocks, seq_dim, blk_in)

out = torch.empty(nblocks, seq_dim, blk_out, device=x.device, dtype=x.dtype)
out = torch.bmm(x_reshaped, weight.transpose(-1, -2), out=out).transpose(
0, 1
) # (nblocks, seq_dim, blk_in) @ (nblocks, blk_in, blk_out) -> (nblocks, seq_dim, blk_out)
return out.reshape(*batch_shape, nblocks * blk_out)

@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dout):
x, weight = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = np.prod(batch_shape)
nblocks, q, p = weight.shape
assert nblocks * p == n
seq_dim = np.prod(batch_shape)
nblocks, blk_out, blk_in = weight.shape
assert nblocks * blk_in == n
dx, dweight = None, None
dout_reshaped = dout.reshape(batch_dim, nblocks, q).transpose(0, 1)
dout_reshaped = dout.reshape(seq_dim, nblocks, blk_out).transpose(0, 1)
if ctx.needs_input_grad[0]:
dx = torch.empty(batch_dim, nblocks, p, device=x.device, dtype=x.dtype)
dx = torch.bmm(dout_reshaped, weight.conj(),
out=dx.transpose(0, 1)).transpose(0, 1).reshape(*batch_shape, n)
dx = torch.empty(seq_dim, nblocks, blk_in, device=x.device, dtype=x.dtype)
dx = (
torch.bmm(dout_reshaped, weight.conj(), out=dx.transpose(0, 1)).transpose(0, 1).reshape(*batch_shape, n)
)
if ctx.needs_input_grad[1]:
x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1)
x_reshaped = x.reshape(seq_dim, nblocks, blk_in).transpose(0, 1)
dweight = torch.bmm(dout_reshaped.transpose(-1, -2), x_reshaped.conj())
return dx, dweight



blockdiag_multiply = BlockdiagMultiply.apply
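
Reviewer note: the analogous quick check for the renamed BlockdiagMultiply, again with hypothetical shapes and an assumed import path:

```python
import torch

from src.mm.blockdiag_multiply import blockdiag_multiply, blockdiag_multiply_reference

torch.manual_seed(0)
nblocks, blk_out, blk_in = 4, 16, 8                  # weight: (nblocks, blk_out, blk_in)
x = torch.randn(2, 5, nblocks * blk_in, requires_grad=True)      # (batch, seq, n)
weight = torch.randn(nblocks, blk_out, blk_in, requires_grad=True)

out = blockdiag_multiply(x, weight)
assert out.shape == (2, 5, nblocks * blk_out)

# Compare against the einsum reference on flattened 2-D input.
ref = blockdiag_multiply_reference(x.reshape(-1, x.shape[-1]), weight)
assert torch.allclose(out.reshape(-1, out.shape[-1]), ref, atol=1e-4)

out.sum().backward()                                  # hand-written backward fills both grads
assert x.grad is not None and weight.grad is not None
```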