If you haven't already, take a look at this post by johnswentworth to understand what this is all about: https://www.lesswrong.com/posts/KnPN7ett8RszE79PH/demons-in-imperfect-search

The short version is that while systems that use perfect search, such as AIXI, have many safety problems, a whole new set of problems arises when we start creating systems that are not perfect searchers. Patterns can form that exploit the imperfect nature of the search function to perpetuate themselves. johnswentworth refers to such patterns as "demons".

After reading that post I decided to see if I could observe demon formation in a simple model: gradient descent on a not-too-complicated mathematical function. It turns out that even in this very simplistic case, demon formation can happen. Hopefully this post will give people an example of demon formation where the mechanism is simple and easy to visualize.

Model

The function we try to minimize using gradient descent is called the loss function. Here it is:

$$L(\vec{x}) = -x_0 + \epsilon \sum_{n=1}^{16} x_n S_n(\vec{x})$$

Let me explain what some of the parts of this loss mean. Each splotch function $S_n$ is periodic with period 2π in every component of $\vec{x}$. I decided in this case to make my splotch functions out of a few randomly chosen sine waves added together.
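
A minimal NumPy sketch of how such splotch functions can be built. It assumes the same construction as the torch code posted in the comments: integer wavenumbers in {-2, ..., 2} and Gaussian sine/cosine coefficients. The sizes, the seed and the names `splotches`, `kn`, `k0`, `s`, `c` are made up for illustration. Because every wavenumber is an integer, shifting any single coordinate by 2π leaves each splotch unchanged.

```python
import numpy as np

# Hypothetical sizes and seed; the construction mirrors the posted torch code.
DIMS, WSUM = 4, 5
rng = np.random.default_rng(0)
kn = rng.integers(-2, 3, size=(DIMS, WSUM, DIMS))  # integer wavenumbers along x_1..x_DIMS
k0 = rng.integers(-2, 3, size=(DIMS, WSUM))        # integer wavenumbers along x_0
s = rng.standard_normal((DIMS, WSUM))              # sine coefficients
c = rng.standard_normal((DIMS, WSUM))              # cosine coefficients

def splotches(x0, xn):
    """Return (S_1(x), ..., S_DIMS(x)) at the point x = (x0, xn)."""
    phase = kn @ xn + k0 * x0                      # shape (DIMS, WSUM): one phase per wave
    return np.sum(s * np.sin(phase) + c * np.cos(phase), axis=1)

x0, xn = 0.7, rng.standard_normal(DIMS)
base = splotches(x0, xn)
shifted_x0 = splotches(x0 + 2 * np.pi, xn)                    # shift along x_0
shifted_xn = splotches(x0, xn + 2 * np.pi * np.eye(DIMS)[1])  # shift along x_2 (index 1 of xn)
print(np.allclose(base, shifted_x0), np.allclose(base, shifted_xn))  # True True
```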

$\epsilon$ is chosen to be a small number, so in any local region $L(\vec{x})$ will look approximately periodic: a bunch of hills repeating over and over again with period 2π across the landscape. But over large enough distances, the relative weightings of the various splotches do change. Travel a distance of 20π in the $x_n$ direction, and $S_n$ will be a larger component of the repeating pattern than it was before. This allows for selection effects.
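
Why the landscape is only approximately periodic can be checked directly. Take the loss to be $L(\vec{x}) = -x_0 + \epsilon\sum_n x_n S_n(\vec{x})$, as in the code posted in the comments. A shift of 2π along $x_n$ leaves every splotch unchanged but moves the weight on $S_n$ by $2\pi\epsilon$, so $L(\vec{x} + 2\pi\hat{e}_n) - L(\vec{x}) = 2\pi\epsilon\,S_n(\vec{x})$. A shift along $x_0$ just lowers the loss by 2π. A NumPy sketch with hypothetical sizes, seed and helper names:

```python
import numpy as np

# Hypothetical random splotch parameters, built like the posted torch code.
DIMS, WSUM, EPSILON = 4, 5, 0.0025
rng = np.random.default_rng(1)
kn = rng.integers(-2, 3, size=(DIMS, WSUM, DIMS))
k0 = rng.integers(-2, 3, size=(DIMS, WSUM))
s = rng.standard_normal((DIMS, WSUM))
c = rng.standard_normal((DIMS, WSUM))

def splotches(x0, xn):
    phase = kn @ xn + k0 * x0
    return np.sum(s * np.sin(phase) + c * np.cos(phase), axis=1)

def loss(x0, xn):
    # L(x) = -x_0 + epsilon * sum_n x_n S_n(x)
    return -x0 + EPSILON * np.dot(xn, splotches(x0, xn))

x0, xn = 3.0, rng.standard_normal(DIMS)
k = 2                                  # shift along index 2 of xn, i.e. the post's x_3
step = 2 * np.pi * np.eye(DIMS)[k]
# Every S_n is unchanged by the shift; only the weight on S_k moves, by 2*pi*epsilon.
diff = loss(x0, xn + step) - loss(x0, xn)
print(np.isclose(diff, 2 * np.pi * EPSILON * splotches(x0, xn)[k]))  # True
# Shifting x_0 by 2*pi leaves the splotches alone and lowers the loss by 2*pi.
drop = loss(x0 + 2 * np.pi, xn) - loss(x0, xn)
print(np.isclose(drop, -2 * np.pi))  # True
```

Starting from $x_n = 0$, a trip of 20π along $x_n$ takes $S_n$'s weight $\epsilon x_n$ from 0 to $20\pi\epsilon \approx 0.157$ for the $\epsilon = 0.0025$ used in the posted code.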

The $-x_0$ term means that the vector $\vec{x}$ mainly wants to increase its $x_0$ component. But the splotch functions can also direct its motion. A splotch function might have a kind of ridge that directs some of the motion into other components. If $S_n$ tends to direct motion in such a way that $x_n$ increases, then it will be selected for, becoming stronger and stronger as time goes on.
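
The selection mechanism is visible in the gradient. With $L = -x_0 + \epsilon\sum_m x_m S_m$, the component along $x_j$ is $\partial L/\partial x_j = \epsilon\,(S_j + \sum_m x_m\,\partial S_m/\partial x_j)$. Every splotch steers the motion in proportion to its own weight $x_m$, so a splotch whose ridges keep pushing its own $x_m$ outward gets a growing say. The NumPy sketch below computes this analytically and checks it against central finite differences. Sizes, seed and the name `grad_xn` are hypothetical.

```python
import numpy as np

# Hypothetical parameters, built like the posted torch code.
DIMS, WSUM, EPSILON = 4, 5, 0.0025
rng = np.random.default_rng(2)
kn = rng.integers(-2, 3, size=(DIMS, WSUM, DIMS))
k0 = rng.integers(-2, 3, size=(DIMS, WSUM))
s = rng.standard_normal((DIMS, WSUM))
c = rng.standard_normal((DIMS, WSUM))

def loss(x0, xn):
    phase = kn @ xn + k0 * x0
    S = np.sum(s * np.sin(phase) + c * np.cos(phase), axis=1)
    return -x0 + EPSILON * np.dot(xn, S)

def grad_xn(x0, xn):
    """Analytic dL/dx_j = EPSILON * (S_j + sum_m x_m dS_m/dx_j)."""
    phase = kn @ xn + k0 * x0
    S = np.sum(s * np.sin(phase) + c * np.cos(phase), axis=1)
    dwave = s * np.cos(phase) - c * np.sin(phase)  # derivative of each wave w.r.t. its phase
    dS = np.einsum('mw,mwj->mj', dwave, kn)        # dS[m, j] = dS_m/dx_j
    direct = S                                     # from the weight x_j on S_j itself
    steering = xn @ dS                             # each splotch m, weighted by x_m, pushes along x_j
    return EPSILON * (direct + steering)

# Compare with central finite differences at a random point.
x0, xn, h = 1.3, 5.0 * rng.standard_normal(DIMS), 1e-6
numeric = np.array([(loss(x0, xn + h * e) - loss(x0, xn - h * e)) / (2 * h)
                    for e in np.eye(DIMS)])
print(np.allclose(grad_xn(x0, xn), numeric, atol=1e-6))  # True
```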

Results

I used ordinary gradient descent, with a constant step size, and with a bit of random noise added in. Figure 1 shows the value of x0 as a function of time, while figure 2 shows the values of x1, x2, …, x16 as a function of time.
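
In the posted torch code, "constant step size" means the gradient is divided by its length before stepping, so every step has the same size, and then Gaussian noise is added. A NumPy sketch of a single step (the function name `noisy_fixed_step` is hypothetical):

```python
import numpy as np

def noisy_fixed_step(x, g, lr=0.2, rng=None):
    """One constant-step-size descent update, with optional Gaussian noise.

    x is the full vector (x_0, x_1, ..., x_D) and g is dL/dx at x. Dividing g by
    its length means the deterministic part of every step has length lr however
    steep or flat the landscape is. The noise is scaled by 1/sqrt(len(x)), so its
    expected squared length is 1, as in the posted code. rng=None turns it off.
    """
    direction = g / np.linalg.norm(g)
    noise = 0.0 if rng is None else rng.standard_normal(x.shape) / np.sqrt(x.size)
    return x - lr * (direction + noise)

x = np.zeros(17)                        # x_0 plus 16 other coordinates, all starting at 0
g = np.zeros(17)
g[0] = -1.0                             # gradient of the -x_0 term alone
print(noisy_fixed_step(x, g)[0])        # 0.2
print(noisy_fixed_step(x, 100 * g)[0])  # also 0.2: the gradient's size does not matter
x_noisy = noisy_fixed_step(x, g, rng=np.random.default_rng(0))
```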

Fig 1: x0 as a function of time (number of steps).

Fig 2: x1, …, x16 as a function of time (number of steps).

There are three phases to the evolution. In the first, $x_0$ increases steadily, and the other coordinates wander around more or less randomly. In the second phase, a self-reinforcing combination of splotches (a "demon") takes hold and amplifies itself drastically, feeding off the large gradient. Finally, this demon becomes so strong that the search gets stuck in a local valley and further progress stops. The first phase runs from roughly 0 to 2500 steps. The second phase is between 2500 and 4000 steps, though progress slows after about 3500. The final phase starts at 4000 steps and likely continues indefinitely.

Now that I have seen demons arise in such a simple situation, it makes me wonder how commonly the same thing happens in the training of deep neural networks. Anyways, hopefully this is a useful model for people who want to understand the mechanisms behind the whole "demons in imperfect search" thing more clearly. It definitely helped me, at least.

Update: The code is now up here: https://github.com/DaemonicSigil/tessellating-hills

18 comments

Now this is one of the more interesting things I've come across.

I fiddled around with the code a bit and was able to reproduce the phenomenon with DIMS = 1, making visualisation possible:

Behold!

Here's the code I used to make the plot:

import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

DIMS = 1   # number of dimensions that xn has
WSUM = 5    # number of waves added together to make a splotch
EPSILON = 0.10 # rate at which xn controls splotch strength
TRAIN_TIME = 5000 # number of iterations to train for
LEARN_RATE = 0.2   # learning rate
MESH_DENSITY = 100 # number of points per axis in the 3D loss-surface mesh (if applicable)

torch.random.manual_seed(1729)

# knlist and k0list are integers, so the splotch functions are periodic
knlist = torch.randint(-2, 3, (DIMS, WSUM, DIMS)) # wavenumbers : list (controlling dim, wave id, k component)
k0list = torch.randint(-2, 3, (DIMS, WSUM))       # the x0 component of wavenumber : list (controlling dim, wave id)
slist = torch.randn((DIMS, WSUM))                # sin coefficients for a particular wave : list(controlling dim, wave id)
clist = torch.randn((DIMS, WSUM))                # cos coefficients for a particular wave : list (controlling dim, wave id)

# initialize x0, xn
x0 = torch.zeros(1, requires_grad=True)
xn = torch.zeros(DIMS, requires_grad=True)

# numpy arrays for plotting:
x0_hist = np.zeros((TRAIN_TIME,))
xn_hist = np.zeros((TRAIN_TIME, DIMS))
loss_hist = np.zeros(TRAIN_TIME,)


def model(xn,x0):
    wavesum = torch.sum(knlist*xn, dim=2) + k0list*x0
    splotch_n = torch.sum(
            (slist*torch.sin(wavesum)) + (clist*torch.cos(wavesum)),
            dim=1)
    foreground_loss = EPSILON * torch.sum(xn * splotch_n)
    return foreground_loss - x0

# train:
for t in range(TRAIN_TIME):

    print(t)
    loss = model(xn,x0)
    loss.backward()
    with torch.no_grad():
        # constant step size gradient descent, with some noise thrown in
        vlen = torch.sqrt(x0.grad*x0.grad + torch.sum(xn.grad*xn.grad))
        x0 -= LEARN_RATE*(x0.grad/vlen + torch.randn(1)/np.sqrt(1.+DIMS))
        xn -= LEARN_RATE*(xn.grad/vlen + torch.randn(DIMS)/np.sqrt(1.+DIMS))
    x0.grad.zero_()
    xn.grad.zero_()
    x0_hist[t] = x0.item()          # x0 has shape (1,); store it as a plain float
    xn_hist[t] = xn.detach().numpy()
    loss_hist[t] = loss.item()      # loss also has shape (1,)

plt.plot(x0_hist)
plt.xlabel('number of steps')
plt.ylabel('x0')
plt.show()
for d in range(DIMS):
    plt.plot(xn_hist[:,d])
plt.xlabel('number of training steps')
plt.ylabel('xn')
plt.show()

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot3D(x0_hist,xn_hist[:,0],loss_hist)

#plot loss landscape
if DIMS == 1:
    x0_range = np.linspace(np.min(x0_hist),np.max(x0_hist),MESH_DENSITY)
    xn_range = np.linspace(np.min(xn_hist),np.max(xn_hist),MESH_DENSITY)
    x,y = np.meshgrid(x0_range,xn_range)
    z = np.zeros((MESH_DENSITY,MESH_DENSITY))
    with torch.no_grad():
        for i, x0_val in enumerate(x0_range):
            for j, xn_val in enumerate(xn_range):
                # distinct loop names so the trained x0/xn tensors are not overwritten
                z[j, i] = model(torch.tensor(xn_val), torch.tensor(x0_val)).item()
    ax.plot_surface(x,y,z,color='orange',alpha=0.3)
ax.set_title("loss")
plt.show()

That's very cool, thanks for making it. At first I was worried that this meant that my model didn't rely on selection effects. Then I tried a few different random seeds, and some, like 1725, didn't show demon-like behaviour. So I think we're still good.

Hmm, the inherent 1D nature of the visualization kinda makes it difficult to check for selection effects. I'm not convinced that's actually what's going on here. 1725 is special because the ridges of the splotch function are exactly orthogonal to x0. The odds of this happening probably go down exponentially with dimensionality. Furthermore, with more dakka, one sees that the optimization rate drops dramatically after ~15000 time steps, and may or may not do so again later. So I don't think this proves selection effects are in play. An alternative hypothesis is simply that the process gets snagged by the first non-orthogonal ridge it encounters, without any serious selection effects coming into play.

Here is the code for people who want to reproduce these results, or just mess around:

import torch
import numpy as np
import matplotlib.pyplot as plt

DIMS = 16   # number of dimensions that xn has
WSUM = 5    # number of waves added together to make a splotch
EPSILON = 0.0025 # rate at which xn controls splotch strength
TRAIN_TIME = 5000 # number of iterations to train for
LEARN_RATE = 0.2   # learning rate

torch.random.manual_seed(1729)

# knlist and k0list are integers, so the splotch functions are periodic
knlist = torch.randint(-2, 3, (DIMS, WSUM, DIMS)) # wavenumbers : list (controlling dim, wave id, k component)
k0list = torch.randint(-2, 3, (DIMS, WSUM))       # the x0 component of wavenumber : list (controlling dim, wave id)
slist = torch.randn((DIMS, WSUM))                # sin coefficients for a particular wave : list(controlling dim, wave id)
clist = torch.randn((DIMS, WSUM))                # cos coefficients for a particular wave : list (controlling dim, wave id)

# initialize x0, xn
x0 = torch.zeros(1, requires_grad=True)
xn = torch.zeros(DIMS, requires_grad=True)

# numpy arrays for plotting:
x0_hist = np.zeros((TRAIN_TIME,))
xn_hist = np.zeros((TRAIN_TIME, DIMS))

# train:
for t in range(TRAIN_TIME):
    ### model: 
    wavesum = torch.sum(knlist*xn, dim=2) + k0list*x0
    splotch_n = torch.sum(
            (slist*torch.sin(wavesum)) + (clist*torch.cos(wavesum)),
            dim=1)
    foreground_loss = EPSILON * torch.sum(xn * splotch_n)
    loss = foreground_loss - x0
    ###
    print(t)
    loss.backward()
    with torch.no_grad():
        # constant step size gradient descent, with some noise thrown in
        vlen = torch.sqrt(x0.grad*x0.grad + torch.sum(xn.grad*xn.grad))
        x0 -= LEARN_RATE*(x0.grad/vlen + torch.randn(1)/np.sqrt(1.+DIMS))
        xn -= LEARN_RATE*(xn.grad/vlen + torch.randn(DIMS)/np.sqrt(1.+DIMS))
    x0.grad.zero_()
    xn.grad.zero_()
    x0_hist[t] = x0.item()          # x0 has shape (1,); store it as a plain float
    xn_hist[t] = xn.detach().numpy()

plt.plot(x0_hist)
plt.xlabel('number of steps')
plt.ylabel('x0')
plt.show()
for d in range(DIMS):
    plt.plot(xn_hist[:,d])
plt.xlabel('number of training steps')
plt.ylabel('xn')
plt.show()

Very nice work. The graphs in particular are quite striking.

I sat down and thought for a bit about whether that objective function is actually a good model for the behavior we're interested in. Twice I thought I saw an issue, then looked back at the definition and realized you'd set up the function to avoid that issue. Solid execution; I think you have actually constructed a demonic environment.

This is awesome, thanks!

So, to check my understanding: You have set up a sort of artificial feedback loop, where there are N overlapping patterns of hills, and each one gets stronger the farther you travel in a particular dimension/direction. So if one or more of these patterns tends systematically to push the ball in the same direction that makes it stronger, you'll get a feedback loop. And then there is selection between patterns, in the sense that the pattern which pushes the strongest will beat the ones that push more weakly, even if both have feedback loops going.

And then the argument is, even though these feedback loops were artificial / baked in by you, in "natural" search problems there might be a similar situation... what exactly is the reason for this? I guess my confusion is in whether to expect real life problems to have this property where moving in a particular direction strengthens a particular pattern. One way I could see this happening is if the patterns are themselves pretty smart, and are able to sense which directions strengthen them at any given moment. Or it could happen if, by chance, there happens to be a direction and a pattern such that the pattern systematically pushes in that direction and the direction systematically strengthens that pattern. But how likely are these? I don't know. I guess your case is a case of the second, but it's rigged a bit, because of how you built in the systematic-strengthening effect.

Am I following, or am I misunderstanding?

Thanks, and your summary is correct. You're also right that this is a pretty contrived model. I don't know exactly how common demons are in real life, and this doesn't really shed much light on that question. I mainly thought that it was interesting to see that demon formation was possible in a simple situation where one can understand everything that is going on.

I have the same confusion.

[anonymous]

Hi, thanks for sharing and experimentally trying out the theory in the previous post! Super cool.

Do you have the code for this up anywhere?

I'm also a little confused by the training procedure. Are you just instantiating a random vector and then doing GD with regards to the loss function you defined? Do the charts show the loss averaged over many random vectors (and function variants)?

Thanks. I initially tried putting the code in a comment on this post, but it ended up being deleted as spam. It's now up on github: https://github.com/DaemonicSigil/tessellating-hills It isn't particularly readable, for which I apologize.

The initial vector has all components set to 0, and the charts show the evolution of these components over time. This is just for a particular run, there isn't any averaging. x0 gets its own chart, since it changes much more than the other components. If you want to know how the loss varies with time, you can just flip figure 1 upside down to get a pretty good proxy, since the splotch functions are of secondary importance compared to the -x0 term.

habryka

Oops, sorry for that. I restored your original comment.

In very high-dimensional spaces, getting stuck in local minima is harder. Do the same results happen in, say, 10,000 dimensions? If so, what's the relationship between the number of dimensions and the time (or progress in x_0) before getting stuck in a daemon? If not, is there another function that exhibits easily found daemons in 10,000 dimensions?

Bit late, but running the same experiment with 1000 dimensions instead of 16, and 10k steps instead of 5k, gives these plots: plot1, plot2

This appears to be heading toward a minimum, though I'm unsure whether I should tweak hyperparameters when scaling up this much. Trying other optimizers would be interesting too, but I think I've been nerd-sniped by this too much already... Code is here.

You're totally not obligated to do this, but I think it might be cool if you generated a 3D picture of hills representing your loss function. I think it would make the intuition for what's going on clearer.

Vaniver

Finally, this demon becomes so strong that the search gets stuck in a local valley and further progress stops.

I don't see why the gradient with respect to x0 ever changes, and so am confused about why it would ever stop increasing in the x0 direction. Does this have to do with using a fixed step size instead of a learning rate?

[Edit: my current thought is that there seems to be a periodic oscillation in the 3rd phase, which is probably an important part of the story. The gradient is mostly about how to point at the center of that well, which means the search orbits that center. Progress in x0 then grinds to a crawl because it's a small fraction of the overall gradient, whereas I think it would continue at a regular pace with a constant learning rate instead.]
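
A quick numerical sketch of the step-size part of this hypothesis. With the normalized update, the gain in $x_0$ per step is $\mathrm{lr}\cdot|g_0|/\lVert g\rVert$, so it shrinks as the other gradient components grow. Plain gradient descent (step = lr × gradient) gains $\mathrm{lr}\cdot|g_0|$ regardless. The sketch holds $g_0 = \partial L/\partial x_0$ fixed at $-1$, which is an assumption, since in the real loss the splotches also vary along $x_0$. `x0_progress` is a made-up helper.

```python
import numpy as np

def x0_progress(g0, gn_norm, lr=0.2):
    """Per-step increase in x_0 for two update rules, ignoring the noise.

    g0 is dL/dx_0 and gn_norm is the length of the gradient's other components.
    """
    fixed_step = lr * -g0 / np.hypot(g0, gn_norm)  # whole step rescaled to length lr
    plain_gd = lr * -g0                            # step = lr * gradient
    return fixed_step, plain_gd

for gn_norm in [0.0, 1.0, 10.0, 100.0]:
    fixed_step, plain_gd = x0_progress(-1.0, gn_norm)
    print(f"|grad_n| = {gn_norm:5.1f}: fixed step {fixed_step:.4f}, plain GD {plain_gd:.4f}")
```

With lr = 0.2, the fixed-step progress falls from 0.2 to about 0.02 per step once the other components reach length 10, while plain gradient descent keeps gaining 0.2.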

Also, did you use any regularization? [Edit: if so, the decrease in response to x0 might actually be present in a one-dimensional version of this, suggesting it's a very different story.]

I don't see why the gradient with respect to x0 ever changes, and so am confused about why it would ever stop increasing in the x0 direction.

Looks like the splotch functions are each a random mixture of sinusoids in each direction, so each will have some variation along $x_0$. The argument of $S_n$ is all of $\vec{x}$, not just $x_n$.
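
To make this concrete: $\partial L/\partial x_0 = -1 + \epsilon\sum_n x_n\,\partial S_n/\partial x_0$. This is exactly $-1$ when all $x_n = 0$ or when no wave has an $x_0$ wavenumber, but it drifts once the $x_n$ have grown. A NumPy sketch with hypothetical parameters, assuming the splotch construction from the posted code (`grad_x0` is an illustrative name):

```python
import numpy as np

# Hypothetical parameters, constructed like the posted torch code.
DIMS, WSUM, EPSILON = 4, 5, 0.0025
rng = np.random.default_rng(3)
kn = rng.integers(-2, 3, size=(DIMS, WSUM, DIMS))
k0 = rng.integers(-2, 3, size=(DIMS, WSUM))
s = rng.standard_normal((DIMS, WSUM))
c = rng.standard_normal((DIMS, WSUM))

def grad_x0(x0, xn, wavenumbers_x0=k0):
    """dL/dx_0 = -1 + EPSILON * sum_n x_n dS_n/dx_0."""
    phase = kn @ xn + wavenumbers_x0 * x0
    dwave = s * np.cos(phase) - c * np.sin(phase)     # derivative of each wave w.r.t. its phase
    dS_dx0 = np.sum(dwave * wavenumbers_x0, axis=1)   # dS_n/dx_0 = sum_w dwave[n, w] * k0[n, w]
    return -1.0 + EPSILON * np.dot(xn, dS_dx0)

xn_big = 50.0 * rng.standard_normal(DIMS)
print(grad_x0(0.5, np.zeros(DIMS)))                # -1.0: the splotches carry no weight yet
print(grad_x0(0.5, xn_big))                        # no longer -1 once the x_n have grown
print(grad_x0(0.5, xn_big, np.zeros_like(k0)))     # -1.0 again if no wave varies along x_0
```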

Vaniver

Ah, that'd do it too.

No regularization was used.

I also can't see any periodic oscillations when I zoom in on the graphs. I think the wobbles you are observing in the third phase are just a result of the random noise that is added to the gradient at each step.