Thanks. It's similar in one sense, but (if I'm reading the paper right) a key difference is that in the MAML examples, the ordering of the meta-level and object level training is such that you still wind up optimizing hard for a particular goal. The idea here is that the two types of training function in opposition, as a control system of sorts, such that the meta-level training should make the model perform worse at the narrow type of task it was trained on.
That said, for sure, the distribution shift problem is still an issue. It seems like this meta-level bias might be less bad than bias at the object level, but I have no idea.
Note: I think there's a decent chance that the idea I describe is misguided, redundant, or naive. My apologies if that is the case, and please feel free to point me towards existing, related writing or research.
Thanks to Peter Barnett and Justis Mills for feedback on a draft of this post. It was inspired by Eliezer's Lethalities post and Zvi's response.
Central idea: can we train AI to generalize out of distribution?
I'm thinking, for example, of an algorithm like the following:
And, of course, we can keep piling layers on.
A few notes
Anyone here know Python?
My hands-on experience with ML extends to linear regression in R and not an inch more, so I'm probably not the best person to test this theory out. I've heard some LWers know a bit of Python, though.
If that's you, I'd be fascinated and thankful to see if you can implement this idea using whatever data and structure you think would work best, and would be happy to collaborate in whatever capacity I can.
Appendix: a few brief comments (from someone with much more domain knowledge than me) and responses (from me):
Comment
Is this just the same as training it on this more complex task (but only doing one big update at the end, rather than doing lots of small updates)?
Response (which may help to clarify why I believe the idea might work)
I don't think so, because the parameters don't change/update/improve between each of those independent tests. Like GPT-3 in some sense has a "memory" of reading Romeo and Juliet, but that's only because its parameters updated as a result of seeing the text.
But also I think my conception depends on the system having "layers" of parameters corresponding to each layer of training.
So train on simple English, and only the "simple English word generation" parameters are allowed to change... but then you tell it how well it did at generalizing out of distribution, and now only its "meta level 1 generalization" parameters are allowed to change.
Then you do the whole thing again but with German text, and its "meta level 1 generalization" parameters are allowed to change again using SGD or whatever. If this works, it will be the reason why it can do well at advanced Hindi text without ever having read advanced Hindi.
Treat this whole process as the object level, and then it updates/improves "meta level 2 generalization" parameters.
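A minimal toy sketch of the "layers of parameters" idea, in Python, might look like the following. Everything in it is an assumption made for illustration: the one-number-per-level parameter groups, the made-up data, the hypothetical helper `update_only`, and the choice to update the meta level with a plain gradient step on held-out loss. It only shows the freezing pattern (one group moves at a time) and is not a claim about how a real implementation would work.

```python
# Hypothetical parameter groups, one per training level.
# In this toy model every group contributes to a single slope.
params = {"object": 0.0, "meta1": 0.0, "meta2": 0.0}


def slope(p):
    return sum(p.values())


def mse_grad(p, data):
    # Gradient of mean squared error with respect to the shared slope:
    # d/ds mean((s*x - y)^2) = mean(2 * (s*x - y) * x).
    # Because the slope is a plain sum, this is also the gradient
    # with respect to each individual group.
    s = slope(p)
    return sum(2 * (s * x - y) * x for x, y in data) / len(data)


def update_only(p, level, data, lr=0.01):
    # One gradient step in which only the named group is allowed to change.
    p[level] -= lr * mse_grad(p, data)


# Made-up stand-ins for "simple" (in-distribution) and
# "advanced" (out-of-distribution) text.
simple_data = [(1.0, 2.0), (2.0, 4.0)]
advanced_data = [(3.0, 6.6), (4.0, 8.8)]

# Object-level phase: only the object-level group moves.
for _ in range(50):
    update_only(params, "object", simple_data)

before_meta = dict(params)

# Meta level 1 phase: feedback comes from the out-of-distribution data,
# and only the meta1 group moves.
update_only(params, "meta1", advanced_data)
```

After the second phase, `params["object"]` is unchanged from `before_meta["object"]`, `params["meta1"]` has moved, and `params["meta2"]` is still untouched. A further outer loop could treat both phases as the object level and step `"meta2"` in the same way.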
Comment
This looks vaguely like curriculum learning, which apparently doesn't really work in LLMs (https://arxiv.org/abs/2108.02170). I think a similar experiment would be to train on simple + advanced text for English, French, Mandarin, etc., but only simple Hindi, and then see if it can do complex Hindi.
Response
I think that's a pretty different thing, because there are no meta level parameters. It seems like fundamentally just a flavor of normal RL.
Or do pretraining with English, French, Mandarin, and Hindi, but only do fine tuning with English, French, Mandarin, and see if it can then do the tasks it was fine tuned for in Hindi.
My prediction: it learns to generalize a bit (the scores on the novel Hindi tasks are higher than if there was no fine tuning with the other languages) but worse than the other languages generalize. As the models are scaled up, this 'generalization gap' gets smaller.
Seems like this might depend on the relative scaling of different meta level parameters (which I described above)?
Like, for example, whenever you scale the number of object level params by a factor of 2, you have to scale the number of nth meta level parameters by a factor of 2^(n+1).
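To spell out the arithmetic of that example rule, here is a small Python sketch. The function name `scaled_meta_params` and the starting counts are hypothetical, and the rule itself is only the illustrative one given above, not an established scaling law.

```python
def scaled_meta_params(base_meta_params, n, object_doublings):
    # Each doubling of the object-level parameter count multiplies the
    # nth meta level's parameter count by 2 ** (n + 1).
    return base_meta_params * 2 ** ((n + 1) * object_doublings)


# Doubling the object-level params once, starting from 100 params per meta level:
level_1 = scaled_meta_params(100, n=1, object_doublings=1)  # 100 * 2**2
level_2 = scaled_meta_params(100, n=2, object_doublings=1)  # 100 * 2**3
```

Under this rule, the nth meta level grows like the object-level count raised to the power n + 1, so higher meta levels would dominate the parameter budget quickly as the model scales.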