[BUG] Error in timm/train.py when using WebDataset (timm/imagenet-w21-wds on Hugging Face) with a class map #2154
Comments
The WebDataset pipeline doesn't have access to the string class names at that point. It uses integer indices from the start, so the class map capability is pretty minimal: it can only really remap integers, and it can't even filter samples out. I'd have to add specific filtering/remap functionality, but I'm not sure it's easy to do in a generic way without looking closer. It would need to use other metadata to get the class name (synset)...
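To illustrate the point above: once a sample's label is an integer, only an integer-keyed mapping can be applied to it, while a map keyed by synset strings has nothing to match. This is a simplified, hypothetical sketch (`remap_label` and the example maps are illustrative, not timm code); it shows one plausible way a KeyError like the one reported could arise, not a confirmed diagnosis.

```python
def remap_label(label, mapping):
    """Plain dict lookup; raises KeyError if `label` is not a key."""
    return mapping[label]


# Hypothetical, simplified pipeline: the label is already an integer here.
int_to_int = {3: 0, 17: 1, 42: 2}               # old index -> new index: works
name_to_int = {"n01440764": 0, "n01443537": 1}  # synset -> new index: no match

label = 17
print(remap_label(label, int_to_int))  # 1
try:
    remap_label(label, name_to_int)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 17
```

Note that an integer remap like this still keeps every sample. Dropping classes requires an extra "skip" path, which is what the following suggestion adds.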
So you mean I currently can't use a class map with WDS? If so, what can I do? Please explain the solution in detail :) I just want to prune classes...
The quickest path would be to hack the _decode function at this line: decode the JSON there, get the class_name from the JSON, and then do a lookup in your class map (you'll have to make the map accessible there). Return None before the image decode if the class isn't in the map (this will skip the sample), and overwrite class_label with the mapped value if it is there. So not too crazy, but not trivial either.
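A minimal sketch of the decode-time filter described above, assuming each raw sample is a dict of byte fields ('json', 'cls', 'jpg') and that the JSON metadata stores the class name under a key called "class_name". Both the key name and `decode_with_class_map` are hypothetical stand-ins, not timm's actual _decode signature. The point is the ordering: parse the small JSON first, bail out with None for unmapped classes, and only decode the image for kept samples.

```python
import io
import json


def _pil_decode(img_bytes, mode="RGB"):
    # Imported lazily so the filtering logic can be exercised without Pillow.
    from PIL import Image
    with io.BytesIO(img_bytes) as b:
        img = Image.open(b)
        img.load()
    return img.convert(mode) if mode else img


def decode_with_class_map(sample, class_to_idx, name_key="class_name",
                          image_key="jpg", decode_image=_pil_decode):
    """Hypothetical WDS sample decoder that filters and remaps by class name.

    sample: dict of raw byte fields from a shard ('json', 'cls', 'jpg', ...).
    class_to_idx: dict mapping class name (e.g. synset) -> new integer label.
    Returns None to skip the sample, otherwise a decoded dict.
    """
    # 1. Parse the JSON metadata first; it is small and cheap to decode.
    meta = json.loads(sample["json"])
    class_name = meta.get(name_key)

    # 2. Look up the class and drop the sample before touching the image bytes.
    new_label = class_to_idx.get(class_name)
    if new_label is None:
        return None

    # 3. Only kept samples pay for image decoding.
    img = decode_image(sample[image_key])

    # 4. Replace the original integer label with the mapped one.
    return dict(jpg=img, cls=new_label, json=sample["json"])
```

Per the comment above, returning None is what causes timm's pipeline to skip the sample. The `decode_image` parameter only exists here so the sketch can be tested without real JPEG bytes.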
Okay, I will try it soon and report back.
Hi rwightman. I added the code you suggested and ran train.py, but it now takes a very long time (no class map: 2 hours per epoch; with class map: 5 days per epoch). What should I do? (I attached reader_wds.py; the change is at line 157 of timm/data/readers/reader_wds.py.)
if class_map.endswith(".txt"):
@TheDarkKnight-21th that's going to be extremely slow; you're loading the same mapping file for every sample. You'd want to add a class_to_idx argument to the _decoder fn. If it's a valid map (a dict with entries), execute the remap block. You also want to remap the actual label, not just filter for valid classes, because in most use cases you'd want to collapse the label space into consecutive indices. In the
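A sketch of the "load once, pass it in" fix, assuming a .txt class map with one class name per line (index = line order) and the same hypothetical "class_name" JSON key as before. `load_class_map`, `remap_sample`, and `build_decoder` are illustrative names, not timm functions. The file is read once when the decoder is built, and functools.partial binds the resulting dict to the per-sample function, so no file I/O happens per sample.

```python
import functools
import json


def load_class_map(path):
    """Read a .txt class map once: one class name per line, index = line order.

    Enumerating the kept classes collapses the label space into consecutive
    indices 0..N-1, which is what a pruned classifier head usually wants.
    """
    with open(path, "r", encoding="utf-8") as f:
        names = [line.strip() for line in f if line.strip()]
    return {name: idx for idx, name in enumerate(names)}


def remap_sample(sample, class_to_idx=None, name_key="class_name"):
    """Hypothetical per-sample hook that reuses a prebuilt dict.

    With an empty or missing map, samples pass through untouched; otherwise
    unknown classes are dropped (None) and 'cls' is replaced by the new index.
    """
    if not class_to_idx:
        return sample
    meta = json.loads(sample["json"])
    new_label = class_to_idx.get(meta.get(name_key))
    if new_label is None:
        return None
    out = dict(sample)
    out["cls"] = new_label
    return out


def build_decoder(class_map_path=None):
    # The file is read here, once per process, not once per sample.
    class_to_idx = load_class_map(class_map_path) if class_map_path else None
    return functools.partial(remap_sample, class_to_idx=class_to_idx)
```

In use, something like `decoder = build_decoder("my_class_map.txt")` (hypothetical path) is created once and then applied to every sample in the pipeline.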
Yeah, you're right. I changed the code as you said: I added class_to_idx to _decoder and also passed the argument in the pipeline partial. But it's still very slow (1 epoch => 3.5 days). (I've included this code as well.) I also posted reader_wds.py in reader_wds.zip; please unzip that file.
With sharded datasets there is no way of knowing how many samples remain valid after filtering, so the dataset length can't be known unless you calculate it yourself. You have to provide a new number of samples as an estimate for the filtered dataset.

With that in mind, since training will keep using the same number of samples and thus steps, significant filtering (i.e., keeping only a small % of the classes) will be extremely inefficient. You have to iterate over ALL the samples in each shard to get the ones you want, so you still read all that data (you only avoid the decode), which is likely to slow your dataloading per sample extracted. If you took, say, 100 of 19,000 classes, you'd have to pass over the dataset ~190x to get the same number of samples.

So you probably want to reevaluate your motivation for doing this. Using, say, 50% of the classes (roughly evenly distributed in frequency) would be okay; using a few thousand or so of the least frequent classes would be inefficient. Your CPU + disk throughput will determine where the limit is...
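The arithmetic above can be made concrete. This is a sketch assuming you have per-class sample counts from your own metadata (the dict passed in is hypothetical input; no real IN21k-winter counts are used here). `estimate_kept_samples` gives the number to supply as the filtered dataset's size. Recent timm versions appear to expose a `--train-num-samples` option for iterable datasets, but verify that against your version's train.py. `passes_needed` reproduces the "~190x" figure if the original sample count is kept instead.

```python
def estimate_kept_samples(class_counts, kept_classes):
    """Number of samples that survive the class filter.

    class_counts: dict class name -> sample count, from your own metadata.
    kept_classes: iterable of class names present in the class map.
    """
    return sum(class_counts.get(name, 0) for name in kept_classes)


def passes_needed(total_samples, kept_samples):
    """Full passes over the shards needed to yield `total_samples` kept samples,
    i.e. the read amplification if the epoch length is NOT reduced."""
    if kept_samples <= 0:
        raise ValueError("no samples survive the filter")
    return total_samples / kept_samples
```

With hypothetical uniform counts of 1,000 samples for each of 19,000 classes, keeping 100 classes leaves 100,000 samples and `passes_needed` returns 190.0. Keeping 9,500 classes (50%) returns 2.0. Setting the sample count to the kept estimate brings each epoch back to roughly one read pass, although that pass still reads every shard.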
To summarize what you've said: if I proceed as described in the quoted text above, I can selectively train on samples using a class_map, but for the reasons mentioned above (e.g., repeated passes over the data for sampling, limited resources), training could be slow.
Describe the bug
I wanted to train on a subset of the IN21k-winter classes, so I made a class map for IN21k-winter and ran training.
However, I get an error in timm/train.py when I use the WebDataset (timm/imagenet-w21-wds) with the class map.
The error is a KeyError.
When I use the class map with ImageFolder (the original IN21k-winter dataset, not WDS), there is no error.
What should I do to run the training script with a class map on timm/imagenet-w21-wds?
(Please check pytorch-image-models/timm/data/readers/reader_wds.py.)
To Reproduce
Expected behavior
I just want to run the training script with a class map on timm/imagenet-w21-wds.
Screenshots
<class map example (.txt)>
<error screenshot>
Desktop
conda list
, 1.7.0 py3.8_cuda11.0.221_cudnn8.0.3_0]
Additional context