Thursday, February 16, 2017

A nice Bayes theorem problem: medical testing

On this previous post about my favorite Bayes theorem problems, I got the following comment from a reader named Riya:

I have a question. Exactly 1/5th of the people in a town have Beaver Fever. There are two tests for Beaver Fever, TEST1 and TEST2. When a person goes to a doctor to test for Beaver Fever, with probability 2/3 the doctor conducts TEST1 on him and with probability 1/3 the doctor conducts TEST2 on him. When TEST1 is done on a person, the outcome is as follows: If the person has the disease, the result is positive with probability 3/4. If the person does not have the disease, the result is positive with probability 1/4. When TEST2 is done on a person, the outcome is as follows: If the person has the disease, the result is positive with probability 1. If the person does not have the disease, the result is positive with probability 1/2. A person is picked uniformly at random from the town and is sent to a doctor to test for Beaver Fever. The result comes out positive. What is the probability that the person has the disease?

I think this is an excellent question, so I am passing it along to the readers of this blog.  One suggestion: you might want to use my world famous Bayesian update worksheet.

Hint: This question is similar to one I wrote about last year.  In that article, I started with a problem that was underspecified; it took a while for me to realize that there were several ways to formulate the problem, with different answers.

Fortunately, the problem posed by Riya is completely specified; it is an example of what I called Scenario A, where there are two tests with different properties, and we don't know which test was used.

There are several ways to proceed, but I recommend writing four hypotheses that specify the test and the status of the patient:

TEST1 and sick
TEST1 and not sick
TEST2 and sick
TEST2 and not sick

For each of these hypotheses, it is straightforward to compute the prior probability and the likelihood of a positive test.  From there, it's just arithmetic.

Here's what it looks like using my world famous Bayesian update worksheet:

(Now with more smudges because I had an arithmetic error the first time.  Thanks, Ben Torvaney, for pointing it out.)

After the update, the total probability that the patient is sick is 10/26 or about 38%.  That's up from the prior, which was 1/5 or 20%.  So the positive test is evidence that the patient is sick, but it is not very strong evidence.

Interestingly, the total posterior probability of TEST2 is 12/26 or about 46%.  That's up from the prior, which was 33%.  So the positive test provides some evidence that TEST2 was used.
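
For readers who would rather check the arithmetic in code than on paper, here is a minimal sketch of the same update using the standard-library `fractions` module. The dictionary layout and the names `priors`, `p_pos`, and `posteriors` are chosen for this example only; they are not part of the worksheet.

```python
from fractions import Fraction as F

p_sick = F(1, 5)
p_test = {'TEST1': F(2, 3), 'TEST2': F(1, 3)}

# Likelihood of a positive result under each (test, status) hypothesis
p_pos = {('TEST1', 'sick'): F(3, 4), ('TEST1', 'not sick'): F(1, 4),
         ('TEST2', 'sick'): F(1), ('TEST2', 'not sick'): F(1, 2)}

# Prior for each hypothesis: P(test) * P(status)
priors = {}
for test, pt in p_test.items():
    priors[test, 'sick'] = pt * p_sick
    priors[test, 'not sick'] = pt * (1 - p_sick)

# Multiply prior by likelihood, then normalize
unnorm = {h: priors[h] * p_pos[h] for h in priors}
total = sum(unnorm.values())
posteriors = {h: u / total for h, u in unnorm.items()}

post_sick = sum(p for (test, status), p in posteriors.items()
                if status == 'sick')
post_test2 = sum(p for (test, status), p in posteriors.items()
                 if test == 'TEST2')

print(post_sick, float(post_sick))    # 5/13, the same as 10/26
print(post_test2, float(post_test2))  # 6/13, the same as 12/26
```

The normalizing constant `total` is the overall probability of a positive test, which comes out to 13/30.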


Monday, January 16, 2017

Last batch of notebooks for Think Stats

Getting ready to teach Data Science in the spring, I am going back through Think Stats and updating the Jupyter notebooks.  Each chapter has a notebook that shows the examples from the book along with some small exercises, with more substantial exercises at the end.

If you are reading the book, you can get the notebooks by cloning this repository on GitHub, and running the notebooks on your computer.

Or you can read (but not run) the notebooks on GitHub:

Chapter 13 Notebook (Chapter 13 Solutions)
Chapter 14 Notebook (Chapter 14 Solutions)

I am done now, just in time for the semester to start, tomorrow! Here are some of the examples from Chapter 13, on survival analysis:


Survival analysis

If we have an unbiased sample of complete lifetimes, we can compute the survival function from the CDF and the hazard function from the survival function.
Here's the distribution of pregnancy length in the NSFG dataset.
In [2]:
import numpy as np
import pandas as pd

import nsfg
import thinkstats2
import thinkplot

preg = nsfg.ReadFemPreg()
complete = preg.query('outcome in [1, 3, 4]').prglngth
cdf = thinkstats2.Cdf(complete, label='cdf')
The survival function is just the complementary CDF.
In [3]:
import survival

def MakeSurvivalFromCdf(cdf, label=''):
    """Makes a survival function based on a CDF.

    cdf: Cdf
    
    returns: SurvivalFunction
    """
    ts = cdf.xs
    ss = 1 - cdf.ps
    return survival.SurvivalFunction(ts, ss, label)
In [4]:
sf = MakeSurvivalFromCdf(cdf, label='survival')
In [5]:
print(cdf[13])
print(sf[13])
0.13978014121
0.86021985879
Here's the CDF and SF.
In [6]:
thinkplot.Plot(sf)
thinkplot.Cdf(cdf, alpha=0.2)
thinkplot.Config(loc='center left')
And here's the hazard function.
In [7]:
hf = sf.MakeHazardFunction(label='hazard')
print(hf[39])
0.676706827309
In [8]:
thinkplot.Plot(hf)
thinkplot.Config(ylim=[0, 0.75], loc='upper left')
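
If you don't have `thinkstats2` and `survival.py` handy, the relationship between the CDF, the survival function, and the hazard function can be sketched with plain NumPy. This is an illustration on made-up lifetimes, not the code from `survival.py`. It uses the standard discrete-time definition of hazard, P(T = t | T >= t), and the indexing in `MakeHazardFunction` may differ.

```python
import numpy as np

def survival_and_hazard(lifetimes):
    """Computes S(t) and a discrete hazard from complete lifetimes.

    Returns ts, ss, lams where ss[i] = P(T > ts[i]) and
    lams[i] = P(T == ts[i] | T >= ts[i]).
    """
    ts, counts = np.unique(lifetimes, return_counts=True)
    n = counts.sum()
    ended_by = np.cumsum(counts)       # number of lifetimes with T <= t
    ss = 1 - ended_by / n              # complementary CDF
    at_risk = n - ended_by + counts    # number of lifetimes with T >= t
    lams = counts / at_risk
    return ts, ss, lams

# Toy data, invented for this example
ts, ss, lams = survival_and_hazard([1, 2, 2, 3, 3, 3, 4, 4])
for t, s, lam in zip(ts, ss, lams):
    print(t, round(s, 3), round(lam, 3))
```

With complete lifetimes, the hazard at the last observed time is always 1, because every lifetime still in progress ends there.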

Age at first marriage

We'll use the NSFG respondent file to estimate the hazard function and survival function for age at first marriage.
In [9]:
resp6 = nsfg.ReadFemResp()
We have to clean up a few variables.
In [10]:
resp6['cmmarrhx'] = resp6.cmmarrhx.replace([9997, 9998, 9999], np.nan)
resp6['agemarry'] = (resp6.cmmarrhx - resp6.cmbirth) / 12.0
resp6['age'] = (resp6.cmintvw - resp6.cmbirth) / 12.0
And then extract the age at first marriage for people who are married, and the age at time of interview for people who are not.
In [11]:
complete = resp6[resp6.evrmarry==1].agemarry.dropna()
ongoing = resp6[resp6.evrmarry==0].age
The following function uses Kaplan-Meier to estimate the hazard function.
In [12]:
from collections import Counter

def EstimateHazardFunction(complete, ongoing, label='', verbose=False):
    """Estimates the hazard function by Kaplan-Meier.

    http://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator

    complete: list of complete lifetimes
    ongoing: list of ongoing lifetimes
    label: string
    verbose: whether to display intermediate results
    """
    if np.sum(np.isnan(complete)):
        raise ValueError("complete contains NaNs")
    if np.sum(np.isnan(ongoing)):
        raise ValueError("ongoing contains NaNs")

    hist_complete = Counter(complete)
    hist_ongoing = Counter(ongoing)

    ts = list(hist_complete | hist_ongoing)
    ts.sort()

    at_risk = len(complete) + len(ongoing)

    lams = pd.Series(index=ts, dtype=float)
    for t in ts:
        ended = hist_complete[t]
        censored = hist_ongoing[t]

        lams[t] = ended / at_risk
        if verbose:
            print(t, at_risk, ended, censored, lams[t])
        at_risk -= ended + censored

    return survival.HazardFunction(lams, label=label)
Here is the hazard function and corresponding survival function.
In [13]:
hf = EstimateHazardFunction(complete, ongoing)
thinkplot.Plot(hf)
thinkplot.Config(xlabel='Age (years)',
                 ylabel='Hazard')
In [14]:
sf = hf.MakeSurvival()
thinkplot.Plot(sf)
thinkplot.Config(xlabel='Age (years)',
                 ylabel='Prob unmarried',
                 ylim=[0, 1])
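
To see the Kaplan-Meier bookkeeping on a small scale, here is a self-contained sketch that follows the same loop as EstimateHazardFunction but uses exact fractions and made-up data. It assumes the survival function is the running product of (1 - hazard), which is the usual Kaplan-Meier construction. The function name `kaplan_meier` and the toy lists are inventions for this example.

```python
from collections import Counter
from fractions import Fraction

def kaplan_meier(complete, ongoing):
    """Returns times, hazards, and survival probabilities.

    complete: list of complete lifetimes
    ongoing: list of censored (still ongoing) lifetimes
    """
    hist_complete = Counter(complete)
    hist_ongoing = Counter(ongoing)
    ts = sorted(set(hist_complete) | set(hist_ongoing))

    at_risk = len(complete) + len(ongoing)
    s = Fraction(1)
    lams, surv = [], []
    for t in ts:
        ended = hist_complete[t]
        censored = hist_ongoing[t]
        lam = Fraction(ended, at_risk)   # fraction of those at risk that end at t
        s *= 1 - lam                     # probability of surviving past t
        lams.append(lam)
        surv.append(s)
        at_risk -= ended + censored      # censored cases leave the risk set too
    return ts, lams, surv

# Toy data, invented for this example
ts, lams, surv = kaplan_meier(complete=[1, 2, 2, 3], ongoing=[2, 3, 4])
for row in zip(ts, lams, surv):
    print(*row)
```

Censored cases never count as events, but they do shrink the risk set, which is why the hazard at t=2 is 2/6 rather than 2/7.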

Quantifying uncertainty

To see how much the results depend on random sampling, we'll use a resampling process again.
In [15]:
def EstimateMarriageSurvival(resp):
    """Estimates the survival curve.

    resp: DataFrame of respondents

    returns: pair of HazardFunction, SurvivalFunction
    """
    # NOTE: Filling missing values would be better than dropping them.
    complete = resp[resp.evrmarry == 1].agemarry.dropna()
    ongoing = resp[resp.evrmarry == 0].age

    hf = EstimateHazardFunction(complete, ongoing)
    sf = hf.MakeSurvival()

    return hf, sf
In [16]:
def ResampleSurvival(resp, iters=101):
    """Resamples respondents and estimates the survival function.

    resp: DataFrame of respondents
    iters: number of resamples
    """ 
    _, sf = EstimateMarriageSurvival(resp)
    thinkplot.Plot(sf)

    low, high = resp.agemarry.min(), resp.agemarry.max()
    ts = np.arange(low, high, 1/12.0)

    ss_seq = []
    for _ in range(iters):
        sample = thinkstats2.ResampleRowsWeighted(resp)
        _, sf = EstimateMarriageSurvival(sample)
        ss_seq.append(sf.Probs(ts))

    low, high = thinkstats2.PercentileRows(ss_seq, [5, 95])
    thinkplot.FillBetween(ts, low, high, color='gray', label='90% CI')
The following plot shows the survival function based on the raw data and a 90% CI based on resampling.
In [17]:
ResampleSurvival(resp6)
thinkplot.Config(xlabel='Age (years)',
                 ylabel='Prob unmarried',
                 xlim=[12, 46],
                 ylim=[0, 1],
                 loc='upper right')
The SF based on the raw data falls outside the 90% CI because the CI is based on weighted resampling, and the raw data is not. You can confirm that by replacing ResampleRowsWeighted with ResampleRows in ResampleSurvival.
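
For intuition, here is a rough sketch of what weighted resampling and the percentile band amount to. It assumes that ResampleRowsWeighted draws rows with replacement, with probability proportional to a sampling-weight column (called `finalwgt` here). The helpers `resample_rows_weighted` and `percentile_band` are stand-ins written for this example, not the `thinkstats2` implementations, and `np.percentile` interpolates, so results may differ slightly from PercentileRows.

```python
import numpy as np
import pandas as pd

def resample_rows_weighted(df, column='finalwgt', rng=None):
    """Draws len(df) rows with replacement, proportional to df[column]."""
    rng = np.random.default_rng() if rng is None else rng
    probs = (df[column] / df[column].sum()).to_numpy()
    idx = rng.choice(len(df), size=len(df), replace=True, p=probs)
    return df.iloc[idx]

def percentile_band(rows, percents=(5, 95)):
    """Column-wise percentiles of a sequence of equal-length rows."""
    return np.percentile(np.asarray(rows), percents, axis=0)

# Toy data, invented for this example: one heavily weighted respondent
df = pd.DataFrame({'x': [1, 2, 3, 4], 'finalwgt': [1, 1, 1, 97]})

rng = np.random.default_rng(17)
means = [[resample_rows_weighted(df, rng=rng)['x'].mean()]
         for _ in range(101)]
low, high = percentile_band(means)
print('unweighted mean', df['x'].mean())
print('90% CI of weighted mean', low[0], high[0])
```

Because the weights pull the resampled estimates toward the heavily weighted rows, an unweighted estimate can land outside the band, which is the same effect as in the plot above.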

More data

To generate survival curves for each birth cohort, we need more data, which we can get by combining data from several NSFG cycles.
In [18]:
resp5 = survival.ReadFemResp1995()
resp6 = survival.ReadFemResp2002()
resp7 = survival.ReadFemResp2010()
In [19]:
resps = [resp5, resp6, resp7]
The following is the code from survival.py that generates SFs broken down by decade of birth.
In [20]:
def AddLabelsByDecade(groups, **options):
    """Draws fake points in order to add labels to the legend.

    groups: GroupBy object
    """
    thinkplot.PrePlot(len(groups))
    for name, _ in groups:
        label = '%d0s' % name
        thinkplot.Plot([15], [1], label=label, **options)

def EstimateMarriageSurvivalByDecade(groups, **options):
    """Groups respondents by decade and plots survival curves.

    groups: GroupBy object
    """
    thinkplot.PrePlot(len(groups))
    for _, group in groups:
        _, sf = EstimateMarriageSurvival(group)
        thinkplot.Plot(sf, **options)

def PlotResampledByDecade(resps, iters=11, predict_flag=False, omit=None):
    """Plots survival curves for resampled data.

    resps: list of DataFrames
    iters: number of resamples to plot
    predict_flag: whether to also plot predictions
    omit: collection of decade names to leave out, or None
    """
    for i in range(iters):
        samples = [thinkstats2.ResampleRowsWeighted(resp) 
                   for resp in resps]
        sample = pd.concat(samples, ignore_index=True)
        groups = sample.groupby('decade')

        if omit:
            groups = [(name, group) for name, group in groups 
                      if name not in omit]

        # TODO: refactor this to collect resampled estimates and
        # plot shaded areas
        if i == 0:
            AddLabelsByDecade(groups, alpha=0.7)

        if predict_flag:
            PlotPredictionsByDecade(groups, alpha=0.1)
            EstimateMarriageSurvivalByDecade(groups, alpha=0.1)
        else:
            EstimateMarriageSurvivalByDecade(groups, alpha=0.2)
Here are the results for the combined data.
In [21]:
PlotResampledByDecade(resps)
thinkplot.Config(xlabel='Age (years)',
                 ylabel='Prob unmarried',
                 xlim=[13, 45],
                 ylim=[0, 1])
We can generate predictions by assuming that the hazard function of each generation will be the same as for the previous generation.
In [22]:
def PlotPredictionsByDecade(groups, **options):
    """Groups respondents by decade and plots predicted survival curves.

    groups: GroupBy object
    """
    hfs = []
    for _, group in groups:
        hf, sf = EstimateMarriageSurvival(group)
        hfs.append(hf)

    thinkplot.PrePlot(len(hfs))
    for i, hf in enumerate(hfs):
        if i > 0:
            hf.Extend(hfs[i-1])
        sf = hf.MakeSurvival()
        thinkplot.Plot(sf, **options)
And here's what that looks like.
In [23]:
PlotResampledByDecade(resps, predict_flag=True)
thinkplot.Config(xlabel='Age (years)',
                 ylabel='Prob unmarried',
                 xlim=[13, 45],
                 ylim=[0, 1])
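
The prediction step is easier to follow on a tiny example. The sketch below assumes that `Extend` appends the part of the previous cohort's hazard function that lies beyond the last age observed for the current cohort, and that the survival function is the cumulative product of (1 - hazard). The function `extend_hazard` and the numbers are made up for illustration; this is not the code from `survival.py`.

```python
import pandas as pd

def extend_hazard(series, other):
    """Returns series followed by the part of other beyond series' last index."""
    last = series.index[-1]
    tail = other[other.index > last]
    return pd.concat([series, tail])

# Toy hazard functions indexed by age, invented for this example
older = pd.Series([0.10, 0.15, 0.30, 0.40], index=[20, 21, 22, 23])
younger = pd.Series([0.10, 0.20], index=[20, 21])

# Borrow the older cohort's hazard for ages the younger cohort hasn't reached
extended = extend_hazard(younger, older)
sf = (1 - extended).cumprod()
print(sf)
```

The younger cohort keeps its own hazard where it has data (ages 20 and 21) and borrows the older cohort's values only for ages 22 and 23.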

Remaining lifetime

Distributions with different shapes yield different behavior for remaining lifetime as a function of age.
In [24]:
preg = nsfg.ReadFemPreg()

complete = preg.query('outcome in [1, 3, 4]').prglngth
print('Number of complete pregnancies', len(complete))
ongoing = preg[preg.outcome == 6].prglngth
print('Number of ongoing pregnancies', len(ongoing))

hf = EstimateHazardFunction(complete, ongoing)
sf1 = hf.MakeSurvival()
Number of complete pregnancies 11189
Number of ongoing pregnancies 352
Here's the expected remaining duration of a pregnancy as a function of the number of weeks elapsed. After week 36, the process becomes "memoryless".
In [25]:
rem_life1 = sf1.RemainingLifetime()
thinkplot.Plot(rem_life1)
thinkplot.Config(title='Remaining pregnancy length',
                 xlabel='Weeks',
                 ylabel='Mean remaining weeks')
And here's the median remaining time until first marriage as a function of age.
In [26]:
hf, sf2 = EstimateMarriageSurvival(resp6)
In [27]:
func = lambda pmf: pmf.Percentile(50)
rem_life2 = sf2.RemainingLifetime(filler=np.inf, func=func)
    
thinkplot.Plot(rem_life2)
thinkplot.Config(title='Years until first marriage',
                 ylim=[0, 15],
                 xlim=[11, 31],
                 xlabel='Age (years)',
                 ylabel='Median remaining years')
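
As a rough illustration of what a remaining-lifetime computation involves, the sketch below conditions a PMF on the lifetime exceeding t, shifts it so that t becomes zero, and takes the mean. The function `mean_remaining_lifetime` and the toy PMF are inventions for this example, and the boundary convention in `RemainingLifetime` (whether a lifetime equal to t counts) may differ.

```python
def mean_remaining_lifetime(pmf, t):
    """Mean of T - t given T > t, for a PMF given as {value: prob}."""
    # Keep only lifetimes longer than t, measured from t
    cond = {x - t: p for x, p in pmf.items() if x > t}
    total = sum(cond.values())          # P(T > t), used to renormalize
    return sum(x * p for x, p in cond.items()) / total

# Toy PMF, invented for this example
pmf = {1: 0.5, 2: 0.25, 3: 0.25}
for t in [0, 1, 2]:
    print(t, mean_remaining_lifetime(pmf, t))
```

Replacing the mean with a median, as in the marriage example above, only changes the summary statistic applied to the conditioned distribution.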

Exercises

Exercise: In NSFG Cycles 6 and 7, the variable cmdivorcx contains the date of divorce for the respondent’s first marriage, if applicable, encoded in century-months.
Compute the duration of marriages that have ended in divorce, and the duration, so far, of marriages that are ongoing. Estimate the hazard and survival curve for the duration of marriage.
Use resampling to take into account sampling weights, and plot data from several resamples to visualize sampling error.
Consider dividing the respondents into groups by decade of birth, and possibly by age at first marriage.
In [28]:
def CleanData(resp):
    """Cleans respondent data.

    resp: DataFrame
    """
    resp['cmdivorcx'] = resp.cmdivorcx.replace([9998, 9999], np.nan)

    resp['notdivorced'] = resp.cmdivorcx.isnull().astype(int)
    resp['duration'] = (resp.cmdivorcx - resp.cmmarrhx) / 12.0
    resp['durationsofar'] = (resp.cmintvw - resp.cmmarrhx) / 12.0

    month0 = pd.to_datetime('1899-12-15')
    dates = [month0 + pd.DateOffset(months=cm) 
             for cm in resp.cmbirth]
    resp['decade'] = (pd.DatetimeIndex(dates).year - 1900) // 10
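
The decade computation depends on the century-month encoding, so here is a small sketch of the conversion in plain integer arithmetic. It assumes that century-month 1 is January 1900, which is consistent with the `1899-12-15` origin used in CleanData. The function names are made up for this example.

```python
def century_month_to_year_month(cm):
    """Converts a century-month code to (year, month), assuming cm=1 is Jan 1900."""
    year = 1900 + (cm - 1) // 12
    month = (cm - 1) % 12 + 1
    return year, month

def decade_of_birth(cm):
    """Decade index as in CleanData: 7 for the 1970s, 8 for the 1980s, and so on."""
    year, _ = century_month_to_year_month(cm)
    return (year - 1900) // 10

for cm in [1, 12, 13, 1000]:
    print(cm, century_month_to_year_month(cm), decade_of_birth(cm))
```

For a whole column, this integer arithmetic can be applied directly to the Series, which avoids building a list of dates one row at a time.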
In [29]:
CleanData(resp6)
married6 = resp6[resp6.evrmarry==1]

CleanData(resp7)
married7 = resp7[resp7.evrmarry==1]
In [30]:
# Solution

def ResampleDivorceCurve(resps):
    """Plots divorce curves based on resampled data.

    resps: list of respondent DataFrames
    """
    for _ in range(11):
        samples = [thinkstats2.ResampleRowsWeighted(resp) 
                   for resp in resps]
        sample = pd.concat(samples, ignore_index=True)
        PlotDivorceCurveByDecade(sample, color='#225EA8', alpha=0.1)

    thinkplot.Show(xlabel='years',
                   axis=[0, 28, 0, 1])
In [31]:
# Solution

def ResampleDivorceCurveByDecade(resps):
    """Plots divorce curves for each birth cohort.

    resps: list of respondent DataFrames    
    """
    for i in range(41):
        samples = [thinkstats2.ResampleRowsWeighted(resp) 
                   for resp in resps]
        sample = pd.concat(samples, ignore_index=True)
        groups = sample.groupby('decade')
        if i == 0:
            survival.AddLabelsByDecade(groups, alpha=0.7)

        EstimateSurvivalByDecade(groups, alpha=0.1)

    thinkplot.Config(xlabel='Years',
                     ylabel='Fraction undivorced',
                     axis=[0, 28, 0, 1])
In [32]:
# Solution

def EstimateSurvivalByDecade(groups, **options):
    """Groups respondents by decade and plots survival curves.

    groups: GroupBy object
    """
    thinkplot.PrePlot(len(groups))
    for name, group in groups:
        _, sf = EstimateSurvival(group)
        thinkplot.Plot(sf, **options)
In [33]:
# Solution

def EstimateSurvival(resp):
    """Estimates the survival curve.

    resp: DataFrame of respondents

    returns: pair of HazardFunction, SurvivalFunction
    """
    complete = resp[resp.notdivorced == 0].duration.dropna()
    ongoing = resp[resp.notdivorced == 1].durationsofar.dropna()

    hf = survival.EstimateHazardFunction(complete, ongoing)
    sf = hf.MakeSurvival()

    return hf, sf
In [34]:
# Solution

ResampleDivorceCurveByDecade([married6, married7])