-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot.py
51 lines (39 loc) · 1.32 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import matplotlib.pyplot as plt
import os
from collections import namedtuple
# One training iteration's loss record: the discriminator loss (D_loss) and
# the generator loss (G_loss).
# The typename is kept identical to the module-level name so that repr() shows
# the name callers actually use and so instances can be pickled (pickle finds
# the class by looking its typename up in the defining module).
Losses = namedtuple('Losses', ('D_loss', 'G_loss'))
class Plots(object):
    """Plot discriminator and generator loss curves for a named run."""

    def __init__(self, name):
        """Remember the run name used to build the saved figure's file name.

        Args:
            name: Base file name (without extension) of the image that
                ``plot_rewards`` writes under ``./figures/``.
        """
        # Not read anywhere in this class; kept so external code that
        # accesses ``counter`` keeps working.
        self.counter = 0
        self.name = name

    def plot_rewards(self, losses, save=False):
        """Draw both loss curves against the training iteration index.

        The curves are drawn on pyplot's current figure, which is closed
        before returning so consecutive calls do not draw on top of each
        other.

        Args:
            losses: Sized sequence of ``(D_loss, G_loss)`` pairs, one per
                training iteration. An empty sequence yields empty curves.
            save: When true, write the figure to ``./figures/<name>.png``,
                creating ``./figures/`` first if it does not exist.
        """
        # zip(*[]) yields nothing, which would make Losses() fail on missing
        # fields, so the empty case is built explicitly.
        if len(losses):
            curves = Losses(*zip(*losses))
        else:
            curves = Losses((), ())
        iterations = list(range(len(losses)))

        plt.plot(iterations, curves.D_loss, label='Discriminator')
        plt.plot(iterations, curves.G_loss, label='Generator')
        plt.xlabel('Training iteration')
        plt.ylabel('Losses')
        # Without a legend the labels passed to plot() are never displayed.
        plt.legend()

        try:
            if save:
                # exist_ok avoids the check-then-create race of isdir + mkdir.
                os.makedirs('./figures/', exist_ok=True)
                plt.savefig(os.path.join('./figures/', self.name + '.png'))
        finally:
            # Release the figure even if saving fails.
            plt.close()
def plot_gan_loss(losses, directory, save=True):
    """Draw discriminator and generator loss curves and optionally save them.

    The curves are drawn on pyplot's current figure, which is closed before
    returning.

    Args:
        losses: Sized sequence of ``(D_loss, G_loss)`` pairs, one per
            training iteration. An empty sequence yields empty curves.
        directory: Directory that receives ``GAN_losses.png``. A trailing
            path separator is accepted but not required; the directory is
            created if it does not exist.
        save: When true (the default), write the figure to
            ``<directory>/GAN_losses.png``; when false, nothing is written.
    """
    # zip(*[]) yields nothing, which would make Losses() fail on missing
    # fields, so the empty case is built explicitly.
    if len(losses):
        curves = Losses(*zip(*losses))
    else:
        curves = Losses((), ())
    iterations = list(range(len(losses)))

    plt.plot(iterations, curves.D_loss, label='Discriminator')
    plt.plot(iterations, curves.G_loss, label='Generator')
    plt.xlabel('Training iteration')
    plt.ylabel('Losses')
    # Without a legend the labels passed to plot() are never displayed.
    plt.legend()

    try:
        if save:
            # An empty directory means "current working directory"; makedirs
            # rejects an empty path, so only create non-empty ones.
            if directory:
                os.makedirs(directory, exist_ok=True)
            # join() gives the same path as plain concatenation when the
            # directory ends with a separator, and a correct one when it
            # does not.
            plt.savefig(os.path.join(directory, 'GAN_losses.png'))
    finally:
        # Release the figure even if saving fails.
        plt.close()