This post is summarized from my final project for Stanford CS229.
The project proposed a framework that leverages machine learning methods to combine information from multiple data sources, with the ultimate goal of generating a de-biased data layer on which health data scientists and researchers can perform analyses.
As a demonstration of the concept, I assumed a hypothetical goal:
To estimate the share of a particular item (A) against a list of competing items, possibly given a set of features.
This looks like a typical problem we’d solve with statistical inference. I tried to tackle the prompt from three perspectives, each with different implications for the need for data centralization:
We simulated different real-world data scenarios leveraging the Fashion MNIST dataset.
For each study scenario, we randomly sampled three datasets (DS-1, DS-2, DS-3) of different sizes from the true population, using fixed random seeds. Random Gaussian noise within each dataset was assumed to be the embedded batch effect, on top of the random sampling noise. Selection biases within each dataset were assumed to follow the distributions specified in Table 1.
To illustrate the underlying research question, the unbiased sample dataset can be expressed as:
\[Y_j^{(i)} \sim b_{data_j}^{(i)} + \beta X_j^{(i)} + \epsilon_j^{(i)}\]
where
\(Y^{(i)}_j \sim \textbf{Multinomial}(n, p_1, \dots, p_{10})\): \(n\) is the sample size, and \(p_1\) through \(p_{10}\) are the proportions of the 10 product labels
\(X^{(i)}_j \sim\) Fashion MNIST (28×28) feature space
\(\epsilon^{(i)}_j \sim\) Fashion MNIST (28×28) sample noise
\(b_{data_j} \sim \mathcal{N}(0, \Sigma_{data_j})\): assumed batch effect, with \((\Sigma_{data_1}, \Sigma_{data_2}, \Sigma_{data_3}) = (0.1, 0.2, 0.1)\)
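As a minimal sketch of this data-generating step (a plausible reading rather than the project's actual code: here the batch effect is added to the image features, and the dataset sizes and the `sample_dataset` helper are illustrative assumptions):

```python
import numpy as np
from sklearn.datasets import fetch_openml

# Load Fashion MNIST once; scale pixels to [0, 1].
X_all, y_all = fetch_openml("Fashion-MNIST", version=1,
                            return_X_y=True, as_frame=False)
X_all = X_all / 255.0
y_all = y_all.astype(int)

def sample_dataset(n, sigma, seed):
    """Draw n samples from the population and add Gaussian batch noise ~ N(0, sigma)."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(X_all), size=n, replace=False)
    X = X_all[idx] + rng.normal(0.0, sigma, size=(n, X_all.shape[1]))  # batch effect b_j
    return X, y_all[idx]

# Batch-effect scales (0.1, 0.2, 0.1) follow the text; dataset sizes are illustrative.
ds1 = sample_dataset(20_000, 0.1, seed=1)
ds2 = sample_dataset(10_000, 0.2, seed=2)
ds3 = sample_dataset(20_000, 0.1, seed=3)
```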
Additional category selection bias was assumed, as specified in Table 1, and is represented here by the variable \(Z_j\):
\[Y_j^{(i)} \sim b_{data_j}^{(i)} + \beta X_j^{(i)} + Z_j^{(i)} + \epsilon_j^{(i)}\]
where
\(Z_j^{(i)} \mid K = k \sim \textbf{Bernoulli}(p_{k,j})\): \(k\) is the hidden real-world category that drives selection bias within a given dataset \(j\) (a sample was observed if \(Z_j = 1\)); \(\mbox{Item}_1, \dots, \mbox{Item}_{10}\) were summarized by category \(K\) without additional noise.
The biased datasets were meant to simulate real-world scenarios in which observed datasets are often non-random subsets of the true population of interest. Specifically, DS-1 was assumed to be biased toward non-target classes, DS-2 to be unbiased, and DS-3 to be biased toward the target class.
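The selection-bias step can be sketched as per-category Bernoulli keep/drop decisions. The keep probabilities below are placeholders standing in for the Table 1 specification, and `target` is a hypothetical item A:

```python
def apply_selection_bias(X, y, keep_prob, seed=0):
    """Keep each sample with probability keep_prob[category]; Z=1 means observed."""
    rng = np.random.default_rng(seed)
    z = rng.random(len(y)) < np.asarray(keep_prob)[y]  # Bernoulli(p_{k,j}) per label k
    return X[z], y[z]

# Placeholder keep probabilities (the real values live in Table 1):
# DS-1 under-samples the target class, DS-3 over-samples it, DS-2 stays unbiased.
target = 0  # e.g. "T-shirt/top" as hypothetical item A
p_biased_away = np.where(np.arange(10) == target, 0.2, 0.9)
p_biased_toward = np.where(np.arange(10) == target, 0.9, 0.2)
X1, y1 = apply_selection_bias(*ds1, p_biased_away, seed=1)
X3, y3 = apply_selection_bias(*ds3, p_biased_toward, seed=3)
```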
Algorithms that predict the outcome of interest. This component studied the scenario in which product labels are unknown and a model is used to predict them. Several state-of-the-art approaches were compared under the embedded selection bias and batch noise, with max-absolute scaling used to normalize the inputs.
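As an illustration of this component (logistic regression is an assumption standing in for the approaches the project compared; the variables come from the sketches above):

```python
from sklearn.preprocessing import MaxAbsScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Max-absolute scaling maps features into [-1, 1] without centering them.
clf = make_pipeline(MaxAbsScaler(), LogisticRegression(max_iter=1000))
clf.fit(X1, y1)         # train on the biased DS-1 sample from above
print(clf.score(*ds2))  # evaluate against the unbiased DS-2 sample
```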
The algorithm that generates the centralized data distribution and predicts the outcome of interest. We utilized and modified a GAN module (SGAN) from an open-source GitHub package (PyTorch-GAN).
SGAN is a semi-supervised GAN that extends the original GAN model: its discriminator predicts class labels in addition to distinguishing real from fake samples.
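Conceptually, the SGAN discriminator carries two heads over a shared trunk: an adversarial real/fake output and an auxiliary label classifier. A stripped-down sketch (layer sizes are illustrative, not the PyTorch-GAN defaults):

```python
import torch
import torch.nn as nn

class SGANDiscriminator(nn.Module):
    """Semi-supervised GAN discriminator: one shared trunk, two output heads."""
    def __init__(self, n_classes=10):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 128), nn.LeakyReLU(0.2),
        )
        self.adv_head = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())  # real vs. fake
        self.aux_head = nn.Linear(128, n_classes + 1)  # 10 labels + 1 "fake" class

    def forward(self, img):
        h = self.trunk(img)
        return self.adv_head(h), self.aux_head(h)
```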
We experimented with three mixtures of DS-1 (biased toward non-target classes), DS-2 (unbiased), and DS-3 (biased toward the target class), with each dataset contributing a 10%, 50%, or 90% subset. The three mixtures were deliberately chosen so that the overall mixture sizes were similar, avoiding extra noise attributable to differing sample sizes. The test sets were the unbiased portions held out from each dataset before selection bias was applied. (A row-stacking sketch follows the list below.)
Experiment #1 (Data Mixture 1): 50% DS-1, 50% DS-2, 50% DS-3
Experiment #2 (Data Mixture 2): 90% DS-1, 10% DS-2, 50% DS-3
Experiment #3 (Data Mixture 3): 50% DS-1, 10% DS-2, 90% DS-3
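A plain row-stacking construction of such a mixture might look like this (fractions follow Experiment #1; the `take_fraction` helper is an illustrative assumption):

```python
def take_fraction(X, y, frac, seed=0):
    """Randomly keep a fraction of a dataset's rows."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(y), size=int(frac * len(y)), replace=False)
    return X[idx], y[idx]

# Experiment #1: 50% of DS-1, 50% of DS-2, 50% of DS-3, row-stacked.
parts = [take_fraction(X1, y1, 0.5),
         take_fraction(*ds2, 0.5),
         take_fraction(X3, y3, 0.5)]
X_mix = np.concatenate([p[0] for p in parts])
y_mix = np.concatenate([p[1] for p in parts])
```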
Key-takeaway #1: A row-stacked data mixture consistently benefited direct inference when the target label was known.
Key-takeaway #2: The more the less-biased dataset dominated the data mixture, the less biased our inference became.
Key-takeaway #3: Model performance across the different data mixtures all fell within a range of 85–87% accuracy. However, performance increased as the share of the unbiased dataset (DS-2) in the mixture increased, which might indicate that more unbiased data during training kept the model from overfitting.
Key-takeaway #4: We encountered a common GAN training issue: the label-classification discriminator became too strong, and the generator gave up on improving its generated data enough to fool it.
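A common mitigation for this failure mode (a standard trick, not something this project verified) is one-sided label smoothing, which keeps the discriminator from becoming over-confident. A minimal sketch, assuming the discriminator emits sigmoid real/fake scores as in the SGAN sketch above:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()

def discriminator_adv_loss(adv_out_real, adv_out_fake):
    """Adversarial loss with one-sided label smoothing: real targets are 0.9
    rather than 1.0, slowing an over-confident discriminator down."""
    real_target = torch.full_like(adv_out_real, 0.9)
    fake_target = torch.zeros_like(adv_out_fake)
    return 0.5 * (bce(adv_out_real, real_target) + bce(adv_out_fake, fake_target))
```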
The findings of this study further confirmed our intuition that the quality and volume of aggregated data sources may be the most crucial factors in industries (e.g., health care) that rely heavily on real-world data. Across all simulated selection-bias scenarios, we found that centralizing all datasets almost always led to less data-attributed bias than relying on the largest individual dataset.

Though we failed to prove that the GAN framework was more powerful at reducing overall data and model bias, we explored and learned a great deal about GANs in the process. One immediate next step would be to run more experiments tuning the SGAN model to improve generator performance. Beyond that, exploring a GAN framework trained on a federated data system (e.g., datasets sitting on different vendors' servers) remains an interesting area I would love to explore next.

Last but not least, the originally proposed framework was meant to scope a dynamic reinforcement learning framework that incorporates the results achieved here. As RWD datasets are commonly refreshed periodically, we expect that a multi-armed-bandit-like reinforcement learning component could help track data drifts in a more timely fashion and allocate data weights based on the target outcome of interest, with the rewards being the level of bias reduction relative to benchmark population statistics.
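As a loose sketch of that future direction (entirely illustrative, not part of the project): an epsilon-greedy bandit could pick among candidate data-source weightings at each refresh, rewarding arms that shrink the gap to benchmark population statistics.

```python
import numpy as np

class EpsilonGreedyBandit:
    """Each arm is a candidate data-source weighting; reward = bias reduction
    relative to benchmark population statistics."""
    def __init__(self, n_arms, epsilon=0.1, seed=0):
        self.rng = np.random.default_rng(seed)
        self.epsilon = epsilon
        self.counts = np.zeros(n_arms)
        self.values = np.zeros(n_arms)  # running mean reward per arm

    def select(self):
        if self.rng.random() < self.epsilon:
            return int(self.rng.integers(len(self.values)))  # explore
        return int(np.argmax(self.values))                   # exploit

    def update(self, arm, reward):
        self.counts[arm] += 1
        self.values[arm] += (reward - self.values[arm]) / self.counts[arm]

# At each data refresh: pick a weighting, measure how far the re-weighted label
# shares sit from the benchmark population shares, and reward the reduction.
```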