统计学基础 | 因果推断之反事实生成对抗网络

酥酥 发布于 2025-02-17 24 次阅读


  反事实生成对抗网络(Counterfactual GANs, CF-GANs)是统计学因果推断中的一种深度学习方法,主要用于模拟在不同处理(或干预)情况下个体的潜在结果(Potential Outcomes)。它结合了生成对抗网络(GANs)和因果推断的反事实分析,以更精确地估计因果效应。

1. 反事实推断的核心问题

    因果推断的核心挑战在于反事实(Counterfactual)无法直接观察。例如,在医疗试验中,对于一个接受了治疗T=1的病人,我们只能观察到他的健康状况Y(1),但无法知道如果他未接受治疗(T=0),他的健康状况Y(0)会如何。这就形成了“缺失的反事实”问题。

目标是估计个体处理效应(Individual Treatment Effect, ITE):

ITE=Y(1)−Y(0)

但由于只能观察到其中之一,估计这一差值变得困难。

2. 生成对抗网络(GANs)的引入

    生成对抗网络(GANs)由一个生成器(Generator, G)和一个判别器(Discriminator, D)组成。生成器G试图生成逼真的数据,而判别器D试图区分真实数据和生成的数据。这种对抗训练使得G逐渐学习到数据的真实分布。

    在反事实推断中,GANs可用于学习数据分布,并生成潜在的反事实结果。CF-GANs的目标是利用GANs生成个体的反事实结果,从而更准确地估计因果效应。

3. CF-GANs 的架构

CF-GANs 主要包括三个部分:

  1. 生成器 G:

    • 以个体特征X和处理T作为输入,学习生成缺失的反事实结果,即:Y^(T)=G(X,T,Z),其中Z是噪声变量。
    • 生成器试图使合成的反事实数据与真实数据分布匹配。
  2. 判别器 D:

    • 试图区分真实的观察结果Y(T)和生成器生成的反事实结果Y^(T)。
    • 通过对抗训练,迫使生成器生成更逼真的反事实数据。
  3. 重构损失(Reconstruction Loss)

    • 由于个体的真实观察结果是已知的,CF-GANs还会引入一个重构损失,确保当 GANs 生成已观察的结果时,与真实值保持一致。
    • 公式上,若个体i观测到Y(T),那么要求:G(X,T)≈Y(T)G(X, T)

4. CF-GANs 的训练过程

CF-GANs 训练过程包括以下步骤:

  1. 初始化生成器和判别器的参数。
  2. 使用已观测数据训练判别器,使其能够分辨真实观察值和生成值。
  3. 训练生成器,使其生成的反事实数据与真实数据分布匹配。
  4. 利用观测数据计算重构损失,确保生成的结果尽可能接近真实数据。
  5. 迭代优化,直到生成器能够逼真地模拟反事实分布。

5. CF-GANs 在因果推断中的应用

  • 医疗与精准医学:估计某种治疗对个体病人的影响,预测未接受治疗时的病情变化。
  • 经济学与政策评估:模拟不同政策实施下的经济结果,例如评估最低工资政策对就业率的影响。
  • 推荐系统:估计如果用户未选择某个产品,他们是否会选择另一种产品,从而优化个性化推荐。

示例

    我们可以用一个模拟医疗试验的数据集来演示 CF-GANs 的原理。假设我们有一个数据集,包含患者的特征、是否接受了治疗(T),以及治疗后的健康状况(Y)。我们的目标是使用 CF-GANs 生成反事实结果,并估计个体治疗效应(ITE)。

示例数据

  • 特征(X):患者的年龄、血压、体重等。
  • 处理(T):是否接受治疗(0 或 1)。
  • 观察结果(Y):患者的健康评分(0-100),表示健康状况。

    接下来,我们用 Python 代码模拟一个小型数据集,并搭建一个简单的 CF-GANs 框架。

				
					
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Step 1: 生成模拟数据
def generate_data(n=1000):
    np.random.seed(42)
    X = np.random.randn(n, 3)  # 三个特征(例如:年龄、血压、体重)
    T = np.random.binomial(1, 0.5, n)  # 处理(0 或 1)
    Y0 = 50 + 10 * X[:, 0] + 5 * X[:, 1] + np.random.randn(n) * 5  # 未治疗的结果
    Y1 = Y0 + 10 + np.random.randn(n) * 2  # 治疗后的结果
    Y = T * Y1 + (1 - T) * Y0  # 观察到的结果
    return X, T, Y, Y0, Y1

X, T, Y, Y0, Y1 = generate_data()

dataset = TensorDataset(torch.tensor(X, dtype=torch.float32), 
                         torch.tensor(T, dtype=torch.float32).unsqueeze(1),
                         torch.tensor(Y, dtype=torch.float32).unsqueeze(1))
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Step 2: 定义 CF-GANs 结构
class Generator(nn.Module):
    def __init__(self, input_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + 1, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 1)
        )

    def forward(self, x, t):
        return self.model(torch.cat([x, t], dim=1))

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + 1, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x, y):
        return self.model(torch.cat([x, y], dim=1))

# 初始化模型
generator = Generator(input_dim=3)
discriminator = Discriminator(input_dim=3)

# Step 3: 训练 CF-GANs
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=0.001)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.001)

for epoch in range(100):
    for x_batch, t_batch, y_batch in train_loader:
        # 训练判别器
        real_labels = torch.ones(y_batch.size(0), 1)
        fake_labels = torch.zeros(y_batch.size(0), 1)

        d_optimizer.zero_grad()
        real_loss = criterion(discriminator(x_batch, y_batch), real_labels)

        fake_y = generator(x_batch, 1 - t_batch)
        fake_loss = criterion(discriminator(x_batch, fake_y.detach()), fake_labels)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # 训练生成器
        g_optimizer.zero_grad()
        g_loss = criterion(discriminator(x_batch, fake_y), real_labels)
        g_loss.backward()
        g_optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}, D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')

# Step 4: 生成反事实结果
def generate_counterfactuals(x, t):
    t_cf = 1 - t  # 反事实处理
    return generator(torch.tensor(x, dtype=torch.float32), torch.tensor(t_cf, dtype=torch.float32).unsqueeze(1)).detach().numpy()

x_test = X[:5]  # 取前5个样本
cf_outcomes = generate_counterfactuals(x_test, T[:5])

print("实际观察到的 Y:", Y[:5])
print("反事实生成的 Y:", cf_outcomes)
				
			

分析 CF-GANs 的执行步骤

  1. 数据生成

    • 我们模拟了1000名患者的特征(年龄、血压、体重)。
    • 随机分配治疗T(0 或 1)。
    • 定义了真实的未治疗和治疗后的健康状况 Y0,Y1。
    • 观测值Y由T决定,即Y=T⋅Y1+(1−T)⋅Y0。
  2. CF-GANs 训练

    • 生成器G生成反事实健康状况(即未观察到的治疗或非治疗结果)。
    • 判别器D尝试区分真实观察值和生成器生成的反事实数据。
    • 通过对抗训练,使得生成的反事实数据逼真。
  3. 反事实估计

    • 训练后,我们用CF-GANs生成某些患者的反事实健康状况。
    • 例如,如果一个患者接受了治疗T=1,我们生成他未接受治疗T=0时的健康状况(反事实)。
  4. ITE 计算

    • 通过计算ITE=Y1−Y0来估计个体治疗效应。

可视化

				
					import matplotlib.pyplot as plt
import seaborn as sns

# 选取部分样本
num_samples = 50
x_test = X[:num_samples]
t_test = T[:num_samples]
y_test = Y[:num_samples]
cf_outcomes = generate_counterfactuals(x_test, t_test)
				
			

1. 真实 vs. 生成的反事实健康状况(散点图)

目标:比较观察到的健康状况Y和CF-GANs生成的反事实Y^。
图形

  • 横轴:患者索引(或个体特征)。
  • 纵轴:健康评分Y。
  • 用 蓝色点 表示观察到的Y(真实值)。
  • 用 红色点 表示CF-GANs生成的反事实 Y^。
  • 可连接真实和反事实点,以显示变化。
				
					
# 1. 真实 vs. 反事实散点图
plt.figure(figsize=(8, 5))
plt.scatter(range(num_samples), y_test, color='blue', label='真实 Y')
plt.scatter(range(num_samples), cf_outcomes, color='red', label='反事实 $\hat{Y}$')
plt.xlabel('样本索引')
plt.ylabel('健康评分')
plt.title('真实 vs. 反事实 评分')
plt.legend()
plt.show()
				
			

2. 个体治疗效应(ITE)分布图

目标:展示GANs生成的个体治疗效应(ITE)估计结果。

ITE=Y^(1)−Y^(0)

图形

  • 直方图或核密度估计图(KDE),展示ITE在总体上的分布情况。
  • 直方图可观察不同个体对治疗的反应是否存在差异。
				
					
# 2. 个体治疗效应 ITE 分布
ITE = cf_outcomes.flatten() - y_test
plt.figure(figsize=(8, 5))
sns.histplot(ITE, bins=20, kde=True)
plt.xlabel('ITE (个体治疗效应)')
plt.ylabel('样本数')
plt.title('ITE 估计分布')
plt.show()
				
			

3. 真实 vs. 生成的分布比较(直方图)

目标:检查GANs生成的健康评分分布是否匹配真实数据分布。
图形

  • 画出 真实的Y(0)和Y(1)分布(蓝色)。
  • 画出 CF-GANs 生成的Y^(0)和Y^(1)分布(红色)。
  • 直方图叠加,若两个分布吻合,表示CF-GANs生成的反事实可靠。
				
					
# 3. 真实 vs. 生成分布对比
plt.figure(figsize=(8, 5))
sns.histplot(Y[T == 0], bins=20, kde=True, color='blue', label='真实 Y(0)', alpha=0.6)
sns.histplot(Y[T == 1], bins=20, kde=True, color='green', label='真实 Y(1)', alpha=0.6)
sns.histplot(cf_outcomes.flatten(), bins=20, kde=True, color='red', label='GAN 生成的反事实 $\hat{Y}$', alpha=0.6)
plt.xlabel('健康评分')
plt.ylabel('样本数')
plt.title('真实 vs. 生成的 Y 分布')
plt.legend()
plt.show()
				
			

可见,CF-GANs通过对抗训练学习数据分布,生成逼真的反事实结果,从而支持因果推断。这种方法在医疗、政策分析等领域具有广泛应用。