for epoch in range(epochs):
# Генерация случайных векторов из латентного пространства
noise = np.random.normal(0, 1, size=[batch_size, random_dim])
# Генерация сгенерированных изображений генератором
generated_images = generator(noise)
# Получение случайных реальных изображений из обучающего набора
image_batch = train_images[np.random.randint(0, train_images.shape[0], size=batch_size)]
# Сборка батча из реальных и сгенерированных изображений
X = np.concatenate([image_batch, generated_images])
# Создание векторов меток для реальных и сгенерированных изображений
y_dis = np.zeros(2 * batch_size)
y_dis[:batch_size] = 0.9 # односторонний мягкий ярлык для гладкости
# Обучение дискриминатора на батче
discriminator.trainable = True
d_loss = discriminator.train_on_batch(X, y_dis)
# Обучение генератора
noise = np.random.normal(0, 1, size=[batch_size, random_dim])
y_gen = np.ones(batch_size)
discriminator.trainable = False
g_loss = gan.train_on_batch(noise, y_gen)
if epoch % 100 == 0:
print(f"Epoch: {epoch}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}")
# Обучение GAN
gan = tf.keras.Sequential([generator, discriminator])
gan.compile(loss='binary_crossentropy', optimizer=generator_optimizer)
train_gan()
```
Код представляет собой простую реализацию генеративной сети (GAN) для генерации реалистичных изображений с использованием библиотек TensorFlow и Keras в Python. Давайте подробно опишем каждую часть кода:
1. Загрузка данных MNIST:
– Загружается набор данных MNIST с рукописными цифрами с помощью функции `tf.keras.datasets.mnist.load_data()`.
– Обучающие изображения сохраняются в переменной `train_images`, а метки классов (которые в данном случае не используются) – в переменной `_`.
– Изображения преобразуются в одномерный формат и нормализуются в диапазоне [-1, 1], чтобы облегчить обучение модели.
2. Определение гиперпараметров:
– `random_dim`: размерность входного шумового вектора (латентного пространства), который будет использоваться для генерации изображений.
– `epochs`: количество эпох обучения GAN.
– `batch_size`: размер батча, используемого для обучения на каждой итерации.
3. Создание генератора (`build_generator`):
– Генератор представляет собой нейронную сеть, которая принимает случайный шум или вектор из латентного пространства и генерирует синтетические изображения.