Wednesday, June 8, 2011

A hierarchical Bayesian model of pond scum

This week I am working with one of my colleagues, the extraordinary biologist Jean Huang, on an interesting problem related to bioinformatics.  This project is my first attempt at implementing a hierarchical Bayesian model in Python, and it has been a pleasant surprise.  If you are familiar with tree-processing algorithms, it's pretty straightforward.  First the likelihoods propagate up the tree, then the updates propagate down.

Here is the background: Prof. Huang and her students collect samples of pond water and culture the bacteria they find under narrow-spectrum artificial light, trying to find species that photosynthesize at different wavelengths.

To identify the species in each culture, they amplify the 16S ribosomal RNA gene, send it out to be sequenced, and then look up the sequence in a database like greengenes.  The result is either the name of a species that has been cultured, the genus of an uncultured species, or an unknown genus.

They usually start by identifying 15 samples from each culture.  For example, one of their cultures yielded 9 samples of one species, three of another, and one each of three more.

Looking at the results for a given culture, biologists would like to be able to answer these questions:

1) How many more species are there likely to be?
2) What is the likely prevalence of each species?
3) If we test m additional samples, how many new species are we likely to find?
4) Given a limited budget, which cultures warrant additional sampling?

To answer these questions, I implemented a hierarchical Bayesian model where the top level is "How many species are there?" and the second level is "Given that there are n species, what is the prevalence of each species?"  The prevalence of a species is its proportion of the population.

Knowing how many species there are helps model their prevalences.  In the simplest case, if we know that there is only 1 species, the prevalence of that species must be 100%.  If we know there are 2 species, I assume that the distribution of prevalences for both species is uniform from 0 to 100%, so the expected prevalence for both is 50%, which makes sense.

If we know there are 3 species, it is less obvious what the prior distribution should be; the Beta distribution is a natural choice because

1) If we know there are k species, we can use Beta(1, k-1) as a prior, and the expected value is 1/k, which makes sense.  For k=2, the result is a uniform distribution, so that makes sense, too.

2) For this scenario, the Beta distribution is a conjugate prior, which means that after an update the posterior is also a Beta distribution.

We can represent a Beta distribution with an object that keeps track of the parameters:


class Beta:
    def __init__(self, yes, no):
        """Initializes a Beta distribution.
        yes and no are the number of 
        successful/unsuccessful trials.
        """
        self.alpha = yes+1
        self.beta = no+1

Update couldn't be easier.

    def Update(self, yes, no):
        """Updates a Beta distribution."""
        self.alpha += yes
        self.beta += no
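
With this class in hand, the prior for a hypothesis with n taxons is just n of these objects.  Here is a minimal sketch of how they might be constructed; the function name is mine, and the real code may organize this differently:

def MakeBetaPriors(n):
    """Makes prior prevalence distributions for n taxons.

    With yes=0 and no=n-2, each prior is Beta(1, n-1),
    so the expected prevalence is 1/n (for n >= 2).
    """
    return [Beta(0, n-2) for _ in range(n)]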

At the top level ("How many species are there?") the prior is a uniform distribution from 1 to 20, at least for now.   I could replace that with a distribution that reflects more of the domain knowledge biologists have about the experiment; for example, based on how the cultures are processed, it would be rare to find more than 10 species with any substantial prevalence.
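
Concretely, that prior could be built with the Pmf class from the thinkstats package (described below).  This is only a sketch; MakeHypo is a stand-in for whatever constructor builds a lower-level hypothesis with n taxons, not necessarily how the real code does it:

import Pmf

def MakeUniformPrior(MakeHypo, low=1, high=20):
    """Makes a uniform prior over the number of taxons.

    MakeHypo: function that takes n and returns a hypothesis
    that says there are n taxons.
    """
    hypos = Pmf.Pmf()
    for n in range(low, high+1):
        hypos.Set(MakeHypo(n), 1)
    hypos.Normalize()
    return hypos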

Now that the priors are in place, how do we update?  Here's the top level:


    def Update(self, evidence):
        """Updates based on observing a given taxon."""
        for hypo, prob in self.hypos.Items():
            likelihood = hypo.Likelihood(evidence)
            if likelihood:
                self.hypos.Mult(hypo, likelihood)
            else:
                # if a hypothesis has been ruled out, remove it
                self.hypos.Remove(hypo)

        self.hypos.Normalize()

The evidence is a string that indicates which taxon (species or genus) was observed.  self.hypos is a Pmf that maps from a hypo to its probability (Pmf is provided by the thinkstats package, a collection of modules used in my book, Think Stats).  Each hypo is a hypothesis about how many taxons there are.  We ask each hypo to compute the likelihood of the evidence, then update hypos accordingly.  If we see N different taxons, all hypotheses with k<N are eliminated.

Having updated the top level, we pass the evidence down to the lower level and update each hypothesis.

        # update the hypotheses
        for hypo, prob in self.hypos.Items():
            hypo.Update(evidence)

To compute the likelihood of the evidence under a hypothesis, we ask "What is the probability of observing this taxon, given our current belief about the prevalence of each taxon?"  There are two cases:

1) If the observed taxon is one we have seen before, we just look up the current Beta distribution and compute its mean:

    def Mean(self):
        """Computes the mean of a Beta distribution."""
        # float() avoids integer division under Python 2
        return float(self.alpha) / (self.alpha + self.beta)

2) If we are seeing a taxon for the first time, we get the likelihoods for all unseen taxons and add them up.
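
Here is a sketch of what that might look like as a Likelihood method on a hypothesis.  I'm assuming the hypothesis keeps its Beta distributions in a dictionary called self.taxa, as in the Update method below, and a set of already-observed taxons called self.seen; those names are mine, and the real code may differ:

    def Likelihood(self, evidence):
        """Computes the probability of observing this taxon.

        If the taxon has been seen before, the likelihood is
        the mean of its prevalence distribution; otherwise it
        is the total probability of drawing any unseen taxon.
        """
        if evidence in self.seen:
            return self.taxa[evidence].Mean()

        # first sighting: add up the likelihoods of all
        # taxons we have not seen yet
        total = 0
        for taxon, dist in self.taxa.items():
            if taxon not in self.seen:
                total += dist.Mean()
        return total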

To update the lower level distributions, we loop through the Beta distributions that represent the prevalences, and update them with either a hit or a miss.

    def Update(self, evidence):
        """Updates each Beta dist with new evidence.

        evidence is a taxon
        """
        for taxon, dist in self.taxa.iteritems():
            if taxon == evidence:
                dist.Update(1, 0)
            else:
                dist.Update(0, 1)

And that's it.  After doing a series of updates, the top level is the posterior distribution of the number of taxons, and each hypothesis at the lower level is a set of Beta distributions that represent the prevalences.

Let's look at an example.  Suppose we see 11 of one taxon, 4 of another, and 2 of a third.  The posterior distribution of the number of taxons gives a probability of 47% that there are only 3 taxons, 28% that there are 4, and 13% that there are 5.
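
In code, this example amounts to feeding each observation through the top-level Update, one sample at a time; something like this, where suite stands for the top-level object built from the priors above, and the labels are arbitrary:

observations = ['a'] * 11 + ['b'] * 4 + ['c'] * 2

for taxon in observations:
    suite.Update(taxon)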

By summing over the hypotheses, we can generate distributions for the prevalence of each taxon.  The median prevalence of taxon A is 58%; the credible interval is (39%, 75%).  For taxons B and C the medians are 23% and 13%.  For taxons D, E, and F, there is a high probability that the prevalence is 0, because these taxons have not been observed.
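
One way to compute those distributions is to mix, for each taxon, the Beta distributions across the surviving hypotheses, weighting each by the posterior probability of its hypothesis; a hypothesis that doesn't include the taxon at all contributes its weight at prevalence 0.  Here is a rough sketch that discretizes each Beta on a grid; the function name and organization are mine:

def MakePrevalencePmf(hypos, taxon, steps=101):
    """Mixes the prevalence distributions for one taxon.

    hypos: Pmf mapping each hypothesis to its probability
    taxon: the taxon to look up
    steps: number of grid points between 0 and 1
    """
    mix = Pmf.Pmf()
    for hypo, prob in hypos.Items():
        dist = hypo.taxa.get(taxon)
        if dist is None:
            # this hypothesis has fewer taxons, so under it
            # the prevalence of this taxon is 0
            mix.Incr(0.0, prob)
            continue

        # discretize this Beta distribution on a grid
        beta_pmf = Pmf.Pmf()
        for i in range(steps):
            x = i / float(steps - 1)
            beta_pmf.Incr(x, x**(dist.alpha-1) * (1-x)**(dist.beta-1))
        beta_pmf.Normalize()

        # fold it into the mixture, weighted by the
        # probability of the hypothesis
        for x, p in beta_pmf.Items():
            mix.Incr(x, prob * p)

    mix.Normalize()
    return mix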

To predict the number of new taxons we might find by sequencing additional samples, I use Monte Carlo simulation.  Here is the kernel of the algorithm:

        for i in range(m):
            taxon = self.GenerateTaxon()
            taxons.append(taxon)
            meta.Update(taxon)

        curve = MakeCurve(taxons)

m is the number of samples to simulate.  Each time through the loop, we generate a random taxon and then update the hierarchy.  The result is a rarefaction curve, which is the number of taxons we observe as a function of the number of samples.

To generate a random taxon, we take advantage of the recursive structure of the hierarchy; that is, we ask each hypothesis to choose a random taxon, and then choose among them in accordance with the probability for each hypothesis:

    def GenerateTaxon(self):
        """Chooses a random taxon."""
        pmf = Pmf.Pmf()
        for hypo, prob in sorted(self.hypos.Items()):
            taxon = hypo.GenerateTaxon()
            pmf.Incr(taxon, prob)
        return pmf.Random()

The algorithm the hypotheses use to generate taxons is ugly, so I will spare you the details, but ultimately it depends on the ability to generate Beta variates, which is provided in Python's random module.
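
For the Beta class above, drawing a variate is a one-liner, using random.betavariate from the standard library (so the module needs import random at the top):

    def Random(self):
        """Draws a random prevalence from this Beta distribution."""
        return random.betavariate(self.alpha, self.beta)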

Given a sequence of taxons (real or simulated), it is easy to generate a rarefaction curve:

def MakeCurve(sample):
    """Makes a rarefaction curve for the given sample."""
    s = set()
    curve = []
    for i, taxon in enumerate(sample):
        s.add(taxon)
        curve.append((i+1, len(s)))
    return curve
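
For example, a hand-made sample shows how the curve counts distinct taxons as samples accumulate:

curve = MakeCurve(['a', 'a', 'b', 'a', 'c'])
# curve is [(1, 1), (2, 1), (3, 2), (4, 2), (5, 3)]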

To show what additional samples might yield, I generate 100 rarefaction curves and plot them.  The first 17 points are random permutations of the observed data, so after 17 samples we have seen 3 taxa on every curve.  The last 15 points are based on simulated data.  The lines are shifted slightly so they don't overlap; that way you can eyeball high-probability paths.

Even after m=15 additional samples, there is a good chance that we don't see any more taxons -- but there is a reasonable chance of seeing 1-2, and a small chance of seeing 3.

Finally, we estimate the probability of seeing additional taxons as a function of the number of additional samples.  With 5 additional samples, the chance of seeing at least one additional taxon is 24%; with 10 samples it's 40%, and with 15 it's 53%.
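
Those probabilities can be read off the simulated curves: for each number of additional samples m, count the fraction of curves that have turned up at least one taxon beyond the ones already observed.  A sketch, assuming each curve is a list of (number of samples, number of taxons) pairs as produced by MakeCurve, with the observed samples at the front:

def ProbNewTaxons(curves, num_observed_samples, num_observed_taxons):
    """Estimates the probability of seeing at least one new taxon.

    curves: list of rarefaction curves, each a list of
            (number of samples, number of taxons) pairs
    num_observed_samples: samples already sequenced
    num_observed_taxons: taxons seen in those samples

    Returns a list of (additional samples, probability) pairs.
    """
    probs = []
    total = len(curves)
    max_m = len(curves[0]) - num_observed_samples
    for m in range(1, max_m + 1):
        index = num_observed_samples + m - 1
        hits = sum(1 for curve in curves
                   if curve[index][1] > num_observed_taxons)
        probs.append((m, hits / float(total)))
    return probs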

These estimates should help biologists allocate their budget for additional samples in order to minimize the chance of missing an important low-prevalence taxon.

Next steps: I am working with Prof. Huang and her students to find the most useful way to report the results; then we are thinking about distributing the code or deploying it as a web app.


Comments:

  1. Thanks for the post. Would you mind adding the complete code? It's a bit confusing what's going where, what calls what, which methods are part of the same class, etc., esp. if we're not already familiar with the package.

  2. Yes, the code is here: http://code.google.com/p/thinkstats/source/browse/trunk/workspace.thinkstats/ThinkStats/rarefaction.py

    It depends on a few of the other modules in the same folder, and on matplotlib.

    But please keep in mind that it is still a work in progress.
