Near zero loss / high capacity hypernetwork artstyle training #7011
-
This is all really interesting, though I think the goal of generating identical images to the training data is a bit weird. Isn't the point of training models to produce new images with the style and concepts of the training data (without replicating it)? Even if it can generalize, I'd be scared of accidentally plagiarizing something when using it. Either way, the stuff about the improper normalization layer implementation is neat and should probably be fixed if the current implementation is wrong. Also, how many steps did it take to train your example network? I'm curious how much of a speedup the correct normalization layer gives.
-
Interesting, but... those hypernetworks are probably way too overfitted to be useful.
-
If anyone is still interested in this ancient technology: https://huggingface.co/lmganon123/pochi_hypernet Here is the best I could do with it. It works, but I am not sure if it is better than a LoRA.
-
Training data:
Training preview:
I am gonna start writing it down now, since I may start to forget things. Basically, I think I finally made something very close to a perfect hypernetwork. By perfect I mean a hypernetwork that can make replicas of the training data if you use the same prompt used during training for each picture. See: #2670 (comment), but in this case it is not 1 training picture but 72. While, as I have shown in that 1-training-picture example, the network can be overpowering when you use the same prompt, changing the prompt shows that it can generalize and retain the style.
I am probably 20 commits behind, and there is a lot of stuff I modified in the file during my trial and error, so I will just attach my hypernetwork.py file.
hypernetwork.zip
If you just copy the file over and it works you can skip the next section.
Most important changes that allow you to actually get to near zero loss
Line 82:
if type(layer) == torch.nn.Linear:
This stops the change to the default initialization of the norm layer. People have been saying that norm layers just slow down training; in reality, the norm layer was initialized improperly. Norm layers not only reduce overfitting, they also significantly speed up training. With proper initialization (the default initialization, with weights set to 1) you can actually see the training loss per epoch start decreasing, and the outputs start to resemble the training data much faster. I also never had a gradient explosion with a norm layer and cosine annealing, regardless of how long I trained the network.
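For context, a minimal sketch of the intent behind that guard (the function name and init values here are assumptions for illustration, not the exact webui code): only re-initialize Linear layers and leave LayerNorm at PyTorch's default of weight = 1, bias = 0.

```python
import torch

def init_hypernetwork_weights(module):
    # Hypothetical illustration: custom init is applied only to Linear layers,
    # while LayerNorm keeps its PyTorch default initialization (weight=1, bias=0).
    for layer in module.modules():
        if type(layer) == torch.nn.Linear:
            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.01)
            if layer.bias is not None:
                torch.nn.init.zeros_(layer.bias)
        # torch.nn.LayerNorm layers are deliberately left untouched.
```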
Line 463:
optimizer = torch.optim.AdamW(params=weights, lr=3e-4, weight_decay=0.05, amsgrad=True)
AMSGrad is helpful. I have seen it speed up the initial phase of learning, where you go from blobs to actual shapes. It also keeps the norms of your weights under control; without it the norms grow much faster, which can throw off the optimal learning rate. It unfortunately increases VRAM usage.
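If you want to check the weight-norm claim yourself, a small hypothetical helper (not part of the webui code) that reports the total L2 norm of the hypernetwork weights could look like this:

```python
import torch

def total_weight_norm(weights):
    # 'weights' is assumed to be the same list of tensors handed to AdamW above.
    # Logging this every few hundred steps makes the faster norm growth
    # without AMSGrad easy to see.
    return torch.sqrt(sum(w.detach().pow(2).sum() for w in weights)).item()
```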
Line 743:
Discussed here: #2670 (comment). I wasted a lot of time because of this. If you don't want to fork the RNG, not making previews is an option. I also think that previews with a -1 seed should be fine. Making previews with a fixed seed without forking the RNG breaks hypernetwork training.
https://textual-inversion.github.io/
My understanding based on the above (can be wrong, would love to be corrected) is that your hypernetwork modifies the noised picture at 4 steps during denoising to change the final denoised picture into your training data. If you don't reset the seed (the correct method of training), then each training step sees a random noised picture, the kind that would be generated with the prompt you train with for that picture, and your hypernetwork tries to modify those noised pictures so the denoising makes them look like the training data. If you set the seed at each step, you give it the exact same noised picture to modify every time. Your loss graph will look great with this, but your hypernetwork is unprepared for random noised pictures, so it will never create anything good.
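A minimal sketch of the RNG-forking idea (the function names here are assumptions, not the actual webui preview code): generate the fixed-seed preview inside a forked RNG scope so the training noise stream is left undisturbed.

```python
import torch

def make_preview(generate_image, preview_seed=12345):
    # torch.random.fork_rng() snapshots the global RNG state and restores it
    # on exit, so seeding a fixed preview here does not turn the next
    # training step's random noise into the same repeated noise.
    with torch.random.fork_rng():
        torch.manual_seed(preview_seed)
        return generate_image()
```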
Settings used in training to reach near zero loss
A first annealing period of roughly 1000 steps, with the period multiplied by 2 at each restart (change this with line 515 in my hypernetwork file); a sketch of such a schedule is below.
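Assuming the schedule is PyTorch's CosineAnnealingWarmRestarts (a guess at the mechanism, not a quote of the modified file), the settings above would translate to something like:

```python
import torch

# 'weights' stands in for the hypernetwork parameters, as in the AdamW line above.
weights = [torch.nn.Parameter(torch.zeros(8, 8))]
optimizer = torch.optim.AdamW(params=weights, lr=3e-4, weight_decay=0.05, amsgrad=True)

# First annealing period of ~1000 steps, doubled at each restart.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1000, T_mult=2)

# scheduler.step() is then called once per training step.
```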
hypernetwork_loss.csv
Loss graph:
Purple is the learn rate, plotted on the right axis. Blue is the mean loss from the last 3 epochs, plotted on the left. In the end the loss oscillates around 0.02.
And here are data points.
Aaaaand I am still very unhappy, because it is 90% there but it is still not what I want. I tried a lot of things to lower the loss further, but none of them worked. And the results at this loss level are far from perfect.
It is very easy to tell which one is the original and which one is the copy. And look at that face:
Also, the shading and coloring are atrocious. I want that final 2% of perfection.
A few remarks and thoughts
Great paper: https://arxiv.org/pdf/1706.05350.pdf
It opened my eyes to a lot of things; page 7 is especially enlightening. Basically, layer normalization removes the effect of weight norms on generalization performance (overcooking). The problem with those neat graphs is that the general trend applies, but the optimal values are probably off: the graphs in the paper were done for batch size 128, and a high batch size requires a much bigger weight decay than the one I used.
I wish I knew the optimal structure of a hypernetwork for this. Bigger is obviously always better with layer norm, but I don't know what would be optimal. If you think about it, the network I made generates the same picture for different seeds and even with dropped-out tags, so it is hard for me to imagine you could contain all that information in a large feed-forward like 1, 8, 8, 1. You need more layers. But how should they be structured, and would some convolutional layers maybe be better?
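For what it's worth, a hypothetical sketch of a deeper webui-style structure (the multipliers, context dimension, and activation here are assumptions for illustration, not a recommendation from this post):

```python
import torch.nn as nn

def build_hypernetwork_module(dim=768, multipliers=(1, 8, 8, 1)):
    # Multipliers are relative to the context dimension, webui-style.
    layers = []
    for i in range(len(multipliers) - 1):
        layers.append(nn.Linear(int(dim * multipliers[i]), int(dim * multipliers[i + 1])))
        if i < len(multipliers) - 2:
            # Norm + activation between hidden layers only, not after the output layer.
            layers.append(nn.LayerNorm(int(dim * multipliers[i + 1])))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)
```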
I am pretty sure you need to use tags, and deepdanbooru is great for this with NAI, given that NAI was trained on danbooru tags. Different tags generate differently noised pictures, which gives your hypernetwork a better chance to tell them apart. I just don't know whether more tags are better or not. Maybe the reason I can't push the loss any lower is that too many tags generate very random results.