Wednesday, February 22, 2017

LSTM musings

This week I'm reading about a type of artificial neural network called the Long Short-Term Memory network (LSTM for short), which I've heard about a number of times but never actually learned anything about beyond the name. The tutorial I'm reading comes highly recommended by my friend Google (I searched "LSTM tutorial").

The details are useful, and the idea of explicitly having memory management in an RNN makes a lot of sense. The basic idea is that if you want your artificial neural network to have memory of stuff that happened in the past, then you should specifically arrange it so that it has that property. Interestingly, just taking some outputs and plugging them back in as inputs (a standard RNN) struggles to learn long-term memories in practice, because the error signal fades (or blows up) as it gets propagated back through many time steps. From my perspective, structuring your network to have a desired property is the same rough idea as how Convolutional Neural Networks apply the same operation over all of visual space, since we expect that vision in the middle of the image should be pretty similar to vision on the edges.
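To make that "explicit memory management" concrete, here's a minimal sketch of a single LSTM step in NumPy. The weight layout and variable names are my own placeholders (not from the tutorial); the point is just that the cell state c carries memory forward, and the gates decide how much of it to keep, overwrite, and expose.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM time step (a sketch; W maps [x, h_prev] to the four gate pre-activations)."""
    z = W @ np.concatenate([x, h_prev]) + b          # one big affine transform
    f, i, o, g = np.split(z, 4)                      # forget, input, output, candidate
    f, i, o = sigmoid(f), sigmoid(i), sigmoid(o)     # gates squashed to (0, 1)
    g = np.tanh(g)                                   # candidate cell contents
    c = f * c_prev + i * g                           # memory: keep some old, write some new
    h = o * np.tanh(c)                               # exposed hidden state / output
    return h, c

# Toy usage: 3-dim input, 4-dim hidden state, random weights.
rng = np.random.default_rng(0)
n_in, n_hid = 3, 4
W = rng.normal(size=(4 * n_hid, n_in + n_hid))
b = np.zeros(4 * n_hid)
h, c = np.zeros(n_hid), np.zeros(n_hid)
for x in rng.normal(size=(5, n_in)):                 # run a short sequence through the cell
    h, c = lstm_step(x, h, c, W, b)
```

Because the forget gate multiplies the old cell state directly, the network can learn to hold on to information for many steps instead of hoping it survives repeated squashing, which is the whole "memory management" trick.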

Anyway, the tutorial does a great job of explaining how they work, so I won't repeat it here. I'm curious what advances have happened since 2015, when it was written. I know people have already tried adding 'attention' to these networks, as mentioned in the conclusion, though I don't know how well that works or how well it mimics the brain's version of attention. Certainly no one has found an LSTM module in the brain yet, so I doubt that people bolting things onto these networks are aiming too hard for biological plausibility.

One thing I struggle with when dealing with these kinds of learning algorithms is that it's very difficult to tell (without playing around with them for a while) which kinds of changes to the architecture will make substantial changes in how the network behaves. The tutorial covers a few LSTM variants that all seem, conceptually, to let the network do the same thing. Do the differences between them matter substantially? It's likely they do, but it's certainly not clear how. I'd love to know if anyone has a decent way to figure out an answer to that kind of question. Is the answer just a shotgun approach: throw your algorithm/structure at everything and see what works? Are there good ways to visualize the loss function that might provide insights? For example, are there classes of problems where the loss function is a spiky mess full of local minima?

I'll continue to ponder all of this, but I'd love to get some input if you have any to give.

Wednesday, February 8, 2017

Interesting paper on why you have to make a good null hypothesis

Today I'll be talking about this paper, titled "Spike-Centered Jitter Can Mistake Temporal Structure." This is the topic that first got me interested in my lab, because I felt I had something to add to the neuroscience community. It's a case where the math is important to get right, and the neuroscience community at large is not as focused on mathematical rigor as I am (that isn't a criticism, and one day I'll write a post about why). Before I get to what the paper is about, let me give some background.

Background


A while ago, we (meaning other scientists who aren't me at all, because I wasn't alive yet) figured out that neurons tend to respond to particular stimuli and are organized in a particular way. We generally describe which stimuli activate a neuron using its receptive field. That response is generally driven by the part of the network that starts with the stimulus (e.g. light hitting your eyes) and goes further into the brain (e.g. visual cortex). This direction (stimulus to brain, and further into the brain) is usually called the feedforward direction.

Next we started asking what the brain does with that feedforward information. Many, many of the connections in the brain aren't feedforward at all. There are lateral connections within an area, feedback loops that span many different areas, and everything in between. If we want to know how the brain works, we had better figure out what all those connections are doing. And to figure out what they're doing, we had better figure out which connections are there and which ones aren't. The "easy" way would be to literally look at all of the connections from one neuron to another. This is really hard because there are tens of billions of neurons in a human brain and the connections are tiny and/or long. So what we want is a way to record from a pair of neurons and see if they're connected, and if we can make it fancier, we want to record from tons of neurons and figure out the wiring diagram.

We started doing just that by looking at the number of times neurons fire at the same time (this is called synchrony), or at some specific delay relative to one another. If they're firing at the same time, they likely share some common input, and if there's a delay, one may be causing the other to fire. Well, that seems to solve the problem, but then you get the age-old scientific question: what's the null hypothesis? And now things get tricky again. You could say "I expect a flat firing rate." Then you can count the number of synchronous spikes and compute how many you would expect by chance pretty easily. But in most experiments, you don't see flat firing rates. Even worse, you can see firing rates that co-vary between neurons even though the neurons aren't connected at all. This can happen because the stimuli that activate both neurons are presented at the same time, or because waves of electrical activity propagating across cortex can drive both neurons in ways that have little to do with direct connections between them.
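To show how easy that flat-rate calculation is, here's a back-of-the-envelope sketch. The rates, recording length, and coincidence window are made-up numbers; the formula just says that if the two neurons are independent and fire at constant (Poisson) rates, each spike from neuron 1 has a small, fixed chance of landing near a spike from neuron 2.

```python
# Expected chance coincidences for two independent neurons firing at constant (Poisson) rates.
# All numbers are made up for illustration.
r1, r2 = 5.0, 8.0      # firing rates, spikes per second
T = 600.0              # recording length in seconds (10 minutes)
delta = 0.005          # coincidence window: spikes within +/- 5 ms count as synchronous

expected = r1 * r2 * (2 * delta) * T
print(f"Expected chance coincidences: {expected:.0f}")   # 240 for these numbers
```

The trouble described above is that once the rates aren't flat, or co-vary across neurons, this simple formula is no longer the right chance level.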

As computers have gotten faster, we've been able to come up with more complicated null hypotheses. A good one (or so I thought until I read this stuff) is the spike-centered jitter hypothesis. It says that if the null hypothesis is true, the number of synchronous spikes should stay about the same if we randomly shift each spike around by a little bit, within a small window centered on its original time. We can simulate this pretty easily by making lots of fake spike trains where the spikes have been shifted around and counting how often the spikes still line up.
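Here's roughly what that Monte Carlo version looks like in code. This is my own sketch of the procedure, not the paper's implementation; the coincidence window, jitter half-width, and surrogate count are placeholder values.

```python
import numpy as np

def count_coincidences(train_a, train_b, window=0.005):
    """Count spike pairs from the two trains that fall within +/- window seconds."""
    diffs = np.abs(train_a[:, None] - train_b[None, :])
    return int(np.sum(diffs <= window))

def spike_centered_jitter_test(train_a, train_b, half_width=0.010,
                               n_surrogates=1000, seed=0):
    """Monte Carlo p-value under the spike-centered jitter null (the method the paper criticizes).

    Each surrogate jitters every spike of train_a uniformly within +/- half_width
    of its ORIGINAL time, then recounts coincidences with train_b.
    """
    rng = np.random.default_rng(seed)
    observed = count_coincidences(train_a, train_b)
    null_counts = np.empty(n_surrogates, dtype=int)
    for k in range(n_surrogates):
        jittered = train_a + rng.uniform(-half_width, half_width, size=train_a.shape)
        null_counts[k] = count_coincidences(jittered, train_b)
    # One-sided p-value: fraction of surrogates with at least as many coincidences as the data.
    return (1 + np.sum(null_counts >= observed)) / (1 + n_surrogates)
```

Here train_a and train_b would be 1-D arrays of spike times in seconds. The key feature (and, as it turns out, the flaw) is that every jitter window is centered on an observed spike.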

On to the New Stuff


It turns out there's a problem with that. If you look at the null hypothesis, it depends on the data. That could be okay if the dependence were handled correctly, but here it isn't. The problem is that if you had a whole bunch of these jittered spike trains plus the original, you could pick out the original as the one whose spikes sit at the center of where the jittered spikes fall. My first reaction to this was to ask: is this just a quirk of the math, or do I actually need to care about it?
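You can see that "pick out the original" trick directly. In this sketch (the spike times and jitter width are made up), averaging a pile of spike-centered surrogates essentially reconstructs the original train, because every jitter distribution is centered on an original spike:

```python
import numpy as np

rng = np.random.default_rng(1)
original = np.array([0.10, 0.35, 0.62, 0.90])   # made-up spike times in seconds
half_width = 0.010

# Many spike-centered surrogates: each spike jittered uniformly around its original time.
surrogates = original + rng.uniform(-half_width, half_width,
                                     size=(10_000, original.size))

# The surrogate average lands right back on the original spikes, so the original train
# is not exchangeable with its surrogates -- which is what breaks the null hypothesis.
print(np.round(surrogates.mean(axis=0), 3))     # approximately [0.1, 0.35, 0.62, 0.9]
```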

Now the authors have shown me the answer: everyone should care. What they do in this paper is show a number of surprisingly reasonable cases where the spike-centered jitter hypothesis would make scientists detect synchrony when there isn't any connection between the spike trains. The best example they give is a case where one of the spike trains fires rhythmically, which does actually happen, and the other is just a few random spikes. They show that if there are enough random spikes, the spike-centered jitter hypothesis can fool you at any p-value threshold you can think of.

They also show that their alternative method, interval jitter (which has already been proven mathematically but takes a bit more computing power and thought), avoids all of these problems.
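For contrast, here's a minimal sketch of the interval-jitter idea: time is cut into fixed, non-overlapping bins whose edges don't depend on where the spikes landed, and each spike is re-drawn uniformly within its bin. This is my own toy version, not the paper's code and not the algorithm I mention below.

```python
import numpy as np

def interval_jitter(spike_times, bin_width=0.010, rng=None):
    """Resample each spike uniformly within its fixed time bin (a minimal sketch).

    Bin edges sit at 0, bin_width, 2*bin_width, ... regardless of the data, so the
    surrogate distribution does not depend on the observed spike positions the way
    spike-centered jitter does.
    """
    rng = np.random.default_rng() if rng is None else rng
    spike_times = np.asarray(spike_times)
    bins = np.floor(spike_times / bin_width)                  # fixed bin index of each spike
    return bins * bin_width + rng.uniform(0.0, bin_width, size=spike_times.shape)
```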

Take-Aways


The first take-away from this paper is obvious: don't use the spike-centered jitter hypothesis. Just don't do it. Use interval jitter instead, or better yet, use my algorithm to implement it! The second take-away is that being careful about your null hypothesis is important. If it has any dependence on your original data, scrutinize it a second or third time to make sure it doesn't fall into the trap that the spike-centered jitter hypothesis does. If you're still unsure, check with a statistician.