Louka Ewington-Pitsos

Comments

It seems like it's possible, right?

Interestingly a similar tool run by a very smart guy (https://www.creatorml.com/) is apparently about to shut down, so it might not be a financially sustainable thing to try.

I couldn't find a link to the code in the article, so in case anyone else wants to try to replicate it, I think this is it: https://github.com/HugoFry/mats_sae_training_for_ViTs

Just to close the loop on this one: the official Hugging Face transformers library just uses a for-loop to achieve MoE. I also implemented a version myself using a for-loop, and for large latent and batch sizes it's much more efficient than either vanilla matrix multiplication or that weird batch matmul I wrote up there.
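For anyone wondering what the for-loop version might look like, here is a minimal sketch. It is not the transformers implementation and not my exact code, just the general idea under some assumptions: top-1 routing, and an `expert_idx` tensor that some router has already produced (hard-coded here). The name `moe_for_loop` is made up for illustration.

```python
import torch


def moe_for_loop(x, enc, dec, expert_idx):
    """Top-1 MoE autoencoder forward pass, looping over experts.

    x:          (bs, input_dim) float tensor
    enc:        (n_experts, input_dim, latent_dim)
    dec:        (n_experts, latent_dim, input_dim)
    expert_idx: (bs,) long tensor, the expert assumed chosen for each row
    """
    n_experts, _, latent_dim = enc.shape
    latent = x.new_zeros(x.shape[0], latent_dim)
    recon = torch.zeros_like(x)
    for e in range(n_experts):
        # Rows of the batch that were routed to expert e.
        rows = (expert_idx == e).nonzero(as_tuple=True)[0]
        if rows.numel() == 0:
            continue
        pre = x[rows] @ enc[e]                 # one ordinary matmul per expert
        latent[rows] = pre                     # pre-activation latent
        recon[rows] = torch.relu(pre) @ dec[e]
    return recon, latent


# Same toy weights as the bmm example from the earlier comment.
x = torch.tensor([[1.0, 2.0]] * 3)
enc = torch.stack([
    torch.ones(2, 4),
    torch.tensor([[3.0, 3.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0]]),
])
dec = torch.stack([-torch.ones(4, 2), -10 * torch.ones(4, 2)])
expert_idx = torch.tensor([0, 1, 0])  # hypothetical router output

recon, latent = moe_for_loop(x, enc, dec, expert_idx)
```

Each expert's weights are stored once and only multiplied against the rows routed to it, so nothing gets copied per batch row.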

Wait a minute... could you just...

You don't just literally do this, do you?

import torch

input = torch.tensor([
    [1, 2],
    [1, 2],
    [1, 2],
]) # (bs, input_dim)


enc_expert_1 = torch.tensor([
    [1, 1, 1, 1],
    [1, 1, 1, 1],

])
enc_expert_2 = torch.tensor([
    [3, 3, 0, 0],
    [0, 0, 2, 0],
])



dec_expert_1 = torch.tensor([
    [ -1, -1],
    [ -1, -1],
    [ -1, -1],
    [ -1, -1],
])

dec_expert_2 = torch.tensor([
    [-10, -10,],
    [-10, -10,],
    [-10, -10,],
    [-10, -10,],

])

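def moe(input, enc, dec, nonlinearity):
    # input: (bs, input_dim), enc: (bs, input_dim, latent_dim), dec: (bs, latent_dim, input_dim)
    input = input.unsqueeze(1)  # (bs, 1, input_dim)
    latent = torch.bmm(input, enc)  # (bs, 1, latent_dim)

    # apply the nonlinearity to the latent first, then decode with the per-row decoder
    recon = torch.bmm(nonlinearity(latent), dec)  # (bs, 1, input_dim)

    return recon.squeeze(1), latent.squeeze(1)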


# Not literally this: an actual routing algorithm would pick the expert for each row,
# but you end up with something similar (one full weight matrix stacked per batch row).
enc = torch.stack([enc_expert_1, enc_expert_2, enc_expert_1])
dec = torch.stack([dec_expert_1, dec_expert_2, dec_expert_1])

nonlinearity = torch.nn.ReLU()
recons, latent = moe(input, enc, dec, nonlinearity)

This must in some way be horrifically inefficient, right?
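A rough way to see where the cost would come from: the stacked `enc` and `dec` hold one full weight matrix per batch row, rather than one per expert. A quick back-of-the-envelope sketch, where all the sizes are made-up numbers purely for illustration:

```python
def stacked_weight_elements(bs, input_dim, latent_dim):
    # bmm approach: one encoder and one decoder matrix per batch row.
    return 2 * bs * input_dim * latent_dim


def shared_weight_elements(n_experts, input_dim, latent_dim):
    # One encoder and one decoder matrix per expert, stored once.
    return 2 * n_experts * input_dim * latent_dim


# Hypothetical sizes, not taken from any real run.
bs, input_dim, latent_dim, n_experts = 4096, 768, 16384, 8

stacked = stacked_weight_elements(bs, input_dim, latent_dim)
shared = shared_weight_elements(n_experts, input_dim, latent_dim)
ratio = stacked // shared  # equals bs // n_experts

print(f"stacked: {stacked * 4 / 2**30:.0f} GiB in float32")
print(f"shared:  {shared * 4 / 2**20:.0f} MiB in float32")
print(f"ratio:   {ratio}x")
```

The blow-up factor is just batch size divided by number of experts, so it gets worse exactly when the batch gets large.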

Can I ask what you used to implement the MoE routing? Did you use megablocks? I would love to expand on this research, but I can't find any straightforward implementation of efficient PyTorch MoE routing online.

Do you simply iterate over each max-probability expert every time you feed in a batch?

This is dope, thank you for your service. Also, can you hit us with your code on this one? Would love to reproduce.