GAN模型的市场情景模拟数据生成¶
本notebook演示如何生成市场情景数据并使用生成对抗网络(GAN)进行建模。我们将:
- 生成不同市场情景下的多资产价格和收益率数据
- 构建和训练GAN模型
- 使用训练好的GAN模型生成新的市场情景
1. 导入必要的库¶
In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import random
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
# 设置随机种子以确保结果可复现
tf.random.set_seed(42)
np.random.seed(42)
random.seed(42)
# 设置绘图样式
plt.style.use('ggplot')
%matplotlib inline
2025-05-13 10:40:08.045875: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. 2025-05-13 10:40:08.046419: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used. 2025-05-13 10:40:08.048537: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used. 2025-05-13 10:40:08.054618: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered WARNING: All log messages before absl::InitializeLog() is called are written to STDERR E0000 00:00:1747104008.066652 1540978 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered E0000 00:00:1747104008.069638 1540978 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered W0000 00:00:1747104008.078191 1540978 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once. W0000 00:00:1747104008.078208 1540978 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once. W0000 00:00:1747104008.078209 1540978 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once. W0000 00:00:1747104008.078210 1540978 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once. 2025-05-13 10:40:08.081748: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2. 市场情景数据生成函数¶
In [2]:
def generate_market_scenario_data(n_days=504, n_assets=5, n_scenarios=100, seed=42):
"""
生成多种市场情景下多资产的价格和收益率数据,用于训练GAN模型
参数:
n_days: 每个情景的交易日数量 (约2年)
n_assets: 资产数量
n_scenarios: 情景数量
seed: 随机种子
返回:
scenarios_df: 包含所有情景的数据框
asset_corr: 资产间的相关性矩阵
scenario_types: 情景类型说明
"""
np.random.seed(seed)
random.seed(seed)
# 定义资产类型
asset_types = [f"Asset_{i+1}" for i in range(n_assets)]
# 定义市场情景类型
scenario_types = {
"bull_market": "牛市上涨趋势",
"bear_market": "熊市下跌趋势",
"sideways_market": "震荡市场",
"recovery": "市场复苏",
"crash": "市场崩盘",
"high_volatility": "高波动率",
"low_volatility": "低波动率",
"sector_rotation": "板块轮动",
"inflation_shock": "通胀冲击",
"interest_rate_hike": "加息环境"
}
# 生成资产间的相关性矩阵
asset_corr = np.zeros((n_assets, n_assets))
for i in range(n_assets):
for j in range(n_assets):
if i == j:
asset_corr[i, j] = 1.0
else:
# 随机生成0.3到0.9之间的相关性
asset_corr[i, j] = asset_corr[j, i] = np.random.uniform(0.3, 0.9)
# 对相关性矩阵进行Cholesky分解,用于生成相关的随机数
try:
L = np.linalg.cholesky(asset_corr)
except np.linalg.LinAlgError:
# 如果矩阵不是正定的,进行调整
asset_corr = asset_corr + np.eye(n_assets) * 0.01
L = np.linalg.cholesky(asset_corr)
# 日期范围
start_date = datetime(2020, 1, 1)
# dates = [start_date + timedelta(days=i) for i in range(n_days)]
# dates = pd.DatetimeIndex([date for date in dates if date.weekday() < 5])[:n_days]
dates = pd.date_range(start=start_date, periods=n_days, freq='B')
# 初始化情景数据结构
scenarios_data = {}
scenario_assignment = {}
# 为每个情景分配一个类型
scenario_type_names = list(scenario_types.keys())
for i in range(n_scenarios):
scenario_assignment[i] = random.choice(scenario_type_names)
# 生成每个情景的数据
for scenario_id in range(n_scenarios):
scenario_type = scenario_assignment[scenario_id]
# 初始化价格和收益率数组
prices = np.zeros((n_days, n_assets))
returns = np.zeros((n_days, n_assets))
# 设置初始价格
initial_prices = np.random.uniform(50, 200, n_assets)
prices[0] = initial_prices
# 根据情景类型设置不同的市场参数
drift = np.zeros(n_assets) # 趋势项
volatility = np.zeros(n_assets) # 波动率
if scenario_type == "bull_market":
drift = np.random.uniform(0.0002, 0.0008, n_assets) # 日均收益约5%-20%年化
volatility = np.random.uniform(0.007, 0.015, n_assets) # 年化波动率约11%-24%
elif scenario_type == "bear_market":
drift = np.random.uniform(-0.0008, -0.0002, n_assets) # 负日均收益
volatility = np.random.uniform(0.010, 0.020, n_assets) # 较高波动率
elif scenario_type == "sideways_market":
drift = np.random.uniform(-0.0001, 0.0001, n_assets) # 接近零的漂移
volatility = np.random.uniform(0.005, 0.010, n_assets) # 中等波动率
elif scenario_type == "recovery":
# 前1/3为下跌,后2/3为上涨
drift = np.random.uniform(0.0004, 0.0012, n_assets) # 强劲复苏
volatility = np.random.uniform(0.009, 0.018, n_assets)
elif scenario_type == "crash":
# 前半部分平稳,后半部分暴跌
drift = np.random.uniform(-0.0020, -0.0010, n_assets) # 大幅下跌
volatility = np.random.uniform(0.020, 0.035, n_assets) # 高波动
elif scenario_type == "high_volatility":
drift = np.random.uniform(-0.0002, 0.0002, n_assets) # 中性漂移
volatility = np.random.uniform(0.020, 0.040, n_assets) # 非常高的波动率
elif scenario_type == "low_volatility":
drift = np.random.uniform(0.0001, 0.0003, n_assets) # 小正漂移
volatility = np.random.uniform(0.003, 0.007, n_assets) # 低波动率
elif scenario_type == "sector_rotation":
# 各资产轮流表现
drift = np.random.uniform(0.0001, 0.0005, n_assets) # 基础漂移
volatility = np.random.uniform(0.008, 0.015, n_assets)
elif scenario_type == "inflation_shock":
drift = np.random.uniform(-0.0005, 0.0, n_assets) # 略微负向
volatility = np.random.uniform(0.012, 0.022, n_assets) # 高波动
elif scenario_type == "interest_rate_hike":
drift = np.random.uniform(-0.0003, 0.0001, n_assets) # 偏负向
volatility = np.random.uniform(0.010, 0.018, n_assets)
# 生成协方差矩阵
cov_matrix = np.outer(volatility, volatility) * asset_corr
# 模拟价格路径
for t in range(1, n_days):
# 根据情景类型调整特定的模拟参数
current_drift = drift.copy()
current_vol = volatility.copy()
if scenario_type == "recovery" and t < n_days // 3:
# 复苏前期为下跌
current_drift = -current_drift
elif scenario_type == "crash" and t > n_days // 2:
# 崩盘后期大幅下跌
if t > n_days * 3 // 4:
current_drift = drift * 3 # 加大跌幅
current_vol = volatility * 1.5 # 加大波动
elif scenario_type == "sector_rotation":
# 实现板块轮动:每隔一段时间,不同资产表现最好
rotation_period = n_days // 5
leading_sector = (t // rotation_period) % n_assets
for a in range(n_assets):
if a == leading_sector:
current_drift[a] = drift[a] * 3 # 领先板块表现更好
else:
current_drift[a] = drift[a] * 0.5 # 其他板块表现一般
elif scenario_type == "inflation_shock" and t > n_days // 2:
# 通胀冲击后期
current_drift = drift * (1 - 0.5 * (t - n_days // 2) / (n_days // 2)) # 漂移逐渐恶化
# 生成相关的随机收益率
z = np.random.normal(0, 1, n_assets)
correlated_z = np.dot(L, z)
# 计算收益率
daily_returns = current_drift + current_vol * correlated_z
# 添加跳跃组件
if np.random.random() < 0.01: # 1%的概率发生市场跳跃
jump_size = np.random.normal(0, 0.02, n_assets) # 跳跃大小
jump_direction = -1 if scenario_type in ["bear_market", "crash"] else 1
daily_returns += jump_direction * jump_size
# 保存收益率
returns[t] = daily_returns
# 计算新价格
prices[t] = prices[t-1] * (1 + daily_returns)
# 将数据转换为DataFrame并保存
scenario_data = {
"scenario_id": scenario_id,
"scenario_type": scenario_type,
"prices": pd.DataFrame(prices, index=dates, columns=asset_types),
"returns": pd.DataFrame(returns, index=dates, columns=asset_types)
}
scenarios_data[scenario_id] = scenario_data
return scenarios_data, asset_corr, scenario_types
3. 生成市场情景数据¶
In [3]:
# 生成数据示例
scenarios, asset_correlations, scenario_descriptions = generate_market_scenario_data(
n_days=504, n_assets=5, n_scenarios=100
)
# 查看第一个情景的价格数据
first_scenario = next(iter(scenarios.values()))
print(f"情景类型: {scenario_descriptions[first_scenario['scenario_type']]}")
print("\n价格数据预览:")
print(first_scenario["prices"].head())
情景类型: 熊市下跌趋势 价格数据预览: Asset_1 Asset_2 Asset_3 Asset_4 Asset_5 2020-01-01 141.777934 70.924079 93.821697 104.954276 118.410498 2020-01-02 140.363239 71.402748 94.019356 103.333459 116.802965 2020-01-03 139.952089 71.932944 93.604270 104.238376 114.078083 2020-01-06 140.397951 72.638981 93.588545 103.862361 116.134854 2020-01-07 139.137211 72.678051 92.386258 104.516222 115.170769
4. 可视化不同市场情景¶
In [4]:
plt.figure(figsize=(15, 20))
scenario_types_shown = set()
plot_count = 0
for scenario_id, scenario_data in scenarios.items():
scenario_type = scenario_data["scenario_type"]
# 确保每种情景类型只展示一次
if scenario_type in scenario_types_shown or plot_count >= 10:
continue
scenario_types_shown.add(scenario_type)
plot_count += 1
plt.subplot(5, 2, plot_count)
prices_df = scenario_data["prices"]
normalized_prices = prices_df / prices_df.iloc[0] # 归一化价格
for asset in prices_df.columns:
plt.plot(normalized_prices.index, normalized_prices[asset], label=asset)
plt.title(f"{scenario_type}: {scenario_descriptions[scenario_type]}")
plt.xlabel("Date")
plt.ylabel("Normalized Price")
plt.grid(True, alpha=0.3)
if plot_count == 2: # 只在第一行的第二个图显示图例
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.35), ncol=3)
plt.tight_layout()
plt.show()
# 统计各类情景数量
scenario_type_counts = {}
for scenario_id, data in scenarios.items():
scenario_type = data["scenario_type"]
scenario_type_counts[scenario_type] = scenario_type_counts.get(scenario_type, 0) + 1
print(f"已生成{len(scenarios)}个市场情景,每个情景包含{len(first_scenario['prices'])}个交易日和{len(first_scenario['prices'].columns)}个资产")
print("\n情景类型分布:")
for scenario_type, count in scenario_type_counts.items():
print(f" {scenario_descriptions[scenario_type]}: {count}个情景 ({count/len(scenarios)*100:.1f}%)")
已生成100个市场情景,每个情景包含504个交易日和5个资产 情景类型分布: 熊市下跌趋势: 14个情景 (14.0%) 牛市上涨趋势: 9个情景 (9.0%) 市场崩盘: 11个情景 (11.0%) 市场复苏: 16个情景 (16.0%) 震荡市场: 8个情景 (8.0%) 通胀冲击: 8个情景 (8.0%) 加息环境: 8个情景 (8.0%) 低波动率: 9个情景 (9.0%) 板块轮动: 6个情景 (6.0%) 高波动率: 11个情景 (11.0%)
5. 为GAN模型准备数据¶
In [5]:
def prepare_data_for_gan(scenarios, n_days_sequence=20):
"""
从生成的情景数据中准备适合GAN训练的数据
参数:
scenarios: 情景数据字典
n_days_sequence: 每个训练样本的天数序列长度
返回:
X_train: 训练数据,形状为 [样本数, 序列长度, 特征数]
scaler: 用于反向转换的MinMaxScaler
"""
sequences = []
# 遍历所有情景
for scenario_id, scenario_data in scenarios.items():
prices_df = scenario_data["prices"]
returns_df = scenario_data["returns"]
# 将价格和收益率数据转换为numpy数组
prices = prices_df.values
returns = returns_df.values
# 计算归一化价格
normalized_prices = prices / prices[0:1]
# 从情景中提取序列
for i in range(len(normalized_prices) - n_days_sequence):
# 提取价格序列
price_sequence = normalized_prices[i:i+n_days_sequence]
# 提取收益率序列
returns_sequence = returns[i:i+n_days_sequence]
# 合并为特征序列 [序列长度, 特征数]
# 特征包括归一化价格和收益率
feature_sequence = np.concatenate([price_sequence, returns_sequence], axis=1)
sequences.append(feature_sequence)
# 将列表转换为numpy数组
sequences = np.array(sequences)
# 处理极端值,使用MinMaxScaler将数据缩放到[-1, 1]范围
n_samples, seq_length, n_features = sequences.shape
sequences_reshaped = sequences.reshape(n_samples * seq_length, n_features)
scaler = MinMaxScaler(feature_range=(-1, 1))
sequences_scaled = scaler.fit_transform(sequences_reshaped)
# 将数据重新整形为3D
X_train = sequences_scaled.reshape(n_samples, seq_length, n_features)
return X_train, scaler
# 准备GAN训练数据
X_train, scaler = prepare_data_for_gan(scenarios, n_days_sequence=20)
print(f"训练数据形状: {X_train.shape}")
训练数据形状: (48400, 20, 10)
6. 构建GAN模型¶
In [6]:
def build_gan_model(seq_length, n_features):
"""
构建GAN模型
参数:
seq_length: 序列长度
n_features: 特征数量
返回:
generator: 生成器模型
discriminator: 判别器模型
gan: 完整GAN模型
"""
# 生成器模型
generator = Sequential(name="Generator")
generator.add(Input(shape=(100,))) # 潜在空间维度
# 隐藏层
generator.add(Dense(128))
generator.add(LeakyReLU(0.2))
generator.add(BatchNormalization())
generator.add(Dense(256))
generator.add(LeakyReLU(0.2))
generator.add(BatchNormalization())
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(BatchNormalization())
# 输出层,生成序列数据
generator.add(Dense(seq_length * n_features, activation='tanh'))
generator.add(tf.keras.layers.Reshape((seq_length, n_features)))
# 判别器模型
discriminator = Sequential(name="Discriminator")
discriminator.add(Input(shape=(seq_length, n_features)))
discriminator.add(tf.keras.layers.Flatten())
# 隐藏层
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(128))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
# 输出层,判断真假
discriminator.add(Dense(1, activation='sigmoid'))
# 编译判别器
discriminator.compile(loss='binary_crossentropy',
optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
metrics=['accuracy'])
# 构建GAN
discriminator.trainable = False # 训练生成器时冻结判别器
gan_input = Input(shape=(100,))
fake_sequence = generator(gan_input)
gan_output = discriminator(fake_sequence)
gan = Model(gan_input, gan_output, name="GAN")
gan.compile(loss='binary_crossentropy',
optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
return generator, discriminator, gan
# 获取序列长度和特征数量
seq_length = X_train.shape[1]
n_features = X_train.shape[2]
# 构建GAN模型
generator, discriminator, gan = build_gan_model(seq_length, n_features)
# 显示模型摘要
print("生成器模型摘要:")
generator.summary()
print("\n判别器模型摘要:")
discriminator.summary()
生成器模型摘要:
E0000 00:00:1747104010.949431 1540978 cuda_executor.cc:1228] INTERNAL: CUDA Runtime error: Failed call to cudaGetRuntimeVersion: Error loading CUDA libraries. GPU will not be used.: Error loading CUDA libraries. GPU will not be used. W0000 00:00:1747104010.951240 1540978 gpu_device.cc:2341] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform. Skipping registering GPU devices...
Model: "Generator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ dense (Dense) │ (None, 128) │ 12,928 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ leaky_re_lu (LeakyReLU) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization │ (None, 128) │ 512 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 256) │ 33,024 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ leaky_re_lu_1 (LeakyReLU) │ (None, 256) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_1 │ (None, 256) │ 1,024 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_2 (Dense) │ (None, 512) │ 131,584 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ leaky_re_lu_2 (LeakyReLU) │ (None, 512) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_2 │ (None, 512) │ 2,048 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_3 (Dense) │ (None, 200) │ 102,600 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ reshape (Reshape) │ (None, 20, 10) │ 0 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 283,720 (1.08 MB)
Trainable params: 281,928 (1.08 MB)
Non-trainable params: 1,792 (7.00 KB)
判别器模型摘要:
Model: "Discriminator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ flatten (Flatten) │ (None, 200) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_4 (Dense) │ (None, 512) │ 102,912 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ leaky_re_lu_3 (LeakyReLU) │ (None, 512) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout (Dropout) │ (None, 512) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_5 (Dense) │ (None, 256) │ 131,328 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ leaky_re_lu_4 (LeakyReLU) │ (None, 256) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_1 (Dropout) │ (None, 256) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_6 (Dense) │ (None, 128) │ 32,896 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ leaky_re_lu_5 (LeakyReLU) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_2 (Dropout) │ (None, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_7 (Dense) │ (None, 1) │ 129 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 267,265 (1.02 MB)
Trainable params: 0 (0.00 B)
Non-trainable params: 267,265 (1.02 MB)
7. 训练GAN模型¶
In [7]:
def train_gan(gan, generator, discriminator, X_train, epochs=300, batch_size=32, latent_dim=100):
"""
训练GAN模型
参数:
gan: GAN模型
generator: 生成器模型
discriminator: 判别器模型
X_train: 训练数据
epochs: 训练轮数
batch_size: 批量大小
latent_dim: 潜在空间维度
返回:
losses: 包含生成器和判别器损失的字典
"""
# 准备训练数据
n_samples = X_train.shape[0]
# 损失历史
losses = {
"d_loss": [],
"g_loss": []
}
for epoch in range(epochs):
# 训练判别器
# 随机选择批量真实数据
idx = np.random.randint(0, n_samples, batch_size)
real_sequences = X_train[idx]
# 生成批量假数据
noise = np.random.normal(0, 1, (batch_size, latent_dim))
fake_sequences = generator.predict(noise)
# 给判别器的标签
real_labels = np.ones((batch_size, 1)) * 0.9 # 平滑标签
fake_labels = np.zeros((batch_size, 1))
# 训练判别器
d_loss_real = discriminator.train_on_batch(real_sequences, real_labels)
d_loss_fake = discriminator.train_on_batch(fake_sequences, fake_labels)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# 训练生成器
# 生成新的随机噪声
noise = np.random.normal(0, 1, (batch_size, latent_dim))
# 训练生成器 (通过GAN模型,判别器权重被冻结)
g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))
# 存储损失
losses["d_loss"].append(d_loss[0])
losses["g_loss"].append(g_loss)
# 打印进度
if epoch % 50 == 0:
print(f"Epoch {epoch}/{epochs}, d_loss: {d_loss[0]:.4f}, g_loss: {g_loss:.4f}")
return losses
# 训练GAN模型
losses = train_gan(gan, generator, discriminator, X_train, epochs=300, batch_size=32, latent_dim=100)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 102ms/step
/home/kk/softwares/anaconda/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py:82: UserWarning: The model does not have any trainable weights. warnings.warn("The model does not have any trainable weights.")
Epoch 0/300, d_loss: 0.7025, g_loss: 0.5477 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step Epoch 50/300, d_loss: 0.7440, g_loss: 0.5786 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step Epoch 100/300, d_loss: 0.7768, g_loss: 0.5315 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 27ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step Epoch 150/300, d_loss: 0.8149, g_loss: 0.4867 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 28ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step Epoch 200/300, d_loss: 0.8571, g_loss: 0.4507 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step Epoch 250/300, d_loss: 0.9001, g_loss: 0.4183 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step
8. 可视化训练损失¶
In [8]:
def plot_losses(losses):
"""绘制训练损失曲线"""
plt.figure(figsize=(10, 5))
plt.plot(losses["d_loss"], label="Discriminator Loss")
plt.plot(losses["g_loss"], label="Generator Loss")
plt.title("GAN Training Losses")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# 可视化训练损失
plot_losses(losses)
9. 生成新的市场情景¶
In [9]:
def plot_generated_sequences(generator, n_samples=3, latent_dim=100, n_assets=5, scaler=None):
"""
生成并绘制样本序列
参数:
generator: 训练好的生成器模型
n_samples: 生成样本数量
latent_dim: 潜在空间维度
n_assets: 资产数量
scaler: 用于反向转换的缩放器
"""
# 生成随机噪声
noise = np.random.normal(0, 1, (n_samples, latent_dim))
# 生成序列
gen_sequences = generator.predict(noise)
# 反向转换数据(如果提供了scaler)
if scaler:
seq_length = gen_sequences.shape[1]
n_features = gen_sequences.shape[2]
gen_sequences_reshaped = gen_sequences.reshape(-1, n_features)
gen_sequences_inverse = scaler.inverse_transform(gen_sequences_reshaped)
gen_sequences = gen_sequences_inverse.reshape(n_samples, seq_length, n_features)
# 绘制生成的序列
fig, axes = plt.subplots(n_samples, 1, figsize=(15, 5*n_samples))
for i in range(n_samples):
if n_samples > 1:
ax = axes[i]
else:
ax = axes
# 提取价格数据 (假设前n_assets列是价格数据)
price_data = gen_sequences[i, :, :n_assets]
# 绘制每个资产的价格
for j in range(n_assets):
ax.plot(range(gen_sequences.shape[1]), price_data[:, j], label=f'Asset_{j+1}')
ax.set_title(f'生成的市场情景 {i+1}')
ax.set_xlabel('时间步')
ax.set_ylabel('归一化价格')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# 生成并可视化市场情景
plot_generated_sequences(generator, n_samples=3, n_assets=5, scaler=scaler)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step
10. 生成更多情景,比较与真实情景的相似度¶
In [10]:
def compare_real_and_generated(X_train, generator, n_samples=3, latent_dim=100, n_assets=5, scaler=None):
"""
比较真实和生成的市场情景
参数:
X_train: 训练数据
generator: 生成器模型
n_samples: 样本数量
latent_dim: 潜在空间维度
n_assets: 资产数量
scaler: 用于反向转换的缩放器
"""
# 随机选择真实样本
indices = np.random.randint(0, X_train.shape[0], n_samples)
real_samples = X_train[indices]
# 生成样本
noise = np.random.normal(0, 1, (n_samples, latent_dim))
generated_samples = generator.predict(noise)
# 如果提供了scaler,反向转换数据
if scaler:
# 反转真实样本
seq_length = real_samples.shape[1]
n_features = real_samples.shape[2]
real_reshaped = real_samples.reshape(-1, n_features)
real_inverse = scaler.inverse_transform(real_reshaped)
real_samples = real_inverse.reshape(n_samples, seq_length, n_features)
# 反转生成样本
gen_reshaped = generated_samples.reshape(-1, n_features)
gen_inverse = scaler.inverse_transform(gen_reshaped)
generated_samples = gen_inverse.reshape(n_samples, seq_length, n_features)
# 绘制比较图
fig, axes = plt.subplots(n_samples, 2, figsize=(15, 5*n_samples))
for i in range(n_samples):
# 真实样本
ax1 = axes[i, 0]
real_price = real_samples[i, :, :n_assets]
for j in range(n_assets):
ax1.plot(range(seq_length), real_price[:, j], label=f'Asset_{j+1}')
ax1.set_title(f'真实市场情景 {i+1}')
ax1.set_xlabel('时间步')
ax1.set_ylabel('归一化价格')
ax1.grid(True, alpha=0.3)
# 生成样本
ax2 = axes[i, 1]
gen_price = generated_samples[i, :, :n_assets]
for j in range(n_assets):
ax2.plot(range(seq_length), gen_price[:, j], label=f'Asset_{j+1}')
ax2.set_title(f'生成的市场情景 {i+1}')
ax2.set_xlabel('时间步')
ax2.set_ylabel('归一化价格')
ax2.grid(True, alpha=0.3)
# 只在第一行显示图例
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.0), ncol=5)
plt.tight_layout(rect=[0, 0, 1, 0.96]) # 为图例留出空间
plt.show()
# 比较真实和生成的市场情景
compare_real_and_generated(X_train, generator, n_samples=3, n_assets=5, scaler=scaler)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step
11. 总结与结论¶
在本notebook中,我们完成了以下工作:
数据生成:创建了一个函数,用于生成多种市场情景下的多资产价格和收益率数据。这些情景包括牛市、熊市、震荡市场、市场复苏、市场崩盘等不同类型。
数据预处理:将生成的数据转换为适合GAN训练的格式,并进行了标准化处理。
模型构建:构建了一个生成对抗网络(GAN),包括生成器和判别器两个部分。
模型训练:训练了GAN模型,使其能够学习市场数据的分布。
结果生成与评估:使用训练好的生成器生成新的市场情景,并与真实情景进行了比较。
GAN模型能够学习并生成类似于真实市场数据的金融时间序列,这种方法可以用于:
- 市场风险估计和压力测试
- 投资策略的回测和优化
- 金融模型的稳健性测试
- 数据增强,为深度学习模型提供更多训练样本
需要注意的是,GAN生成的数据虽然在统计特性上与真实数据相似,但不能直接用于预测未来市场走势。它的主要价值在于提供多样化的市场情景模拟,帮助金融从业者和研究人员更好地理解和应对各种可能的市场环境。