-
-
Notifications
You must be signed in to change notification settings - Fork 41
/
utils.py
253 lines (232 loc) · 9.18 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import hashlib
import os
from typing import Iterable
import shutil
import subprocess
import re
from collections.abc import Mapping
import torch
import server
#Symmetric integer bounds of +/-(2**53 - 1); this magnitude equals
#JavaScript's Number.MAX_SAFE_INTEGER.
#NOTE(review): presumably used as numeric limits elsewhere - confirm at use sites
BIGMAX = 2**53 - 1
BIGMIN = -BIGMAX
#NOTE(review): presumably an upper bound on a dimension value - confirm at use sites
DIMMAX = 8192
def ffmpeg_suitability(path):
    """Score an ffmpeg executable so that candidates can be ranked.

    Runs ``path -version`` and inspects the text it prints.

    Args:
        path: Path of the executable to probe.

    Returns:
        An integer score (higher is better). 0 is returned when the
        executable cannot be run, exits with an error, or prints output
        that cannot be decoded.
    """
    try:
        version = subprocess.run([path, "-version"], check=True,
                                 capture_output=True).stdout.decode("utf-8")
    except Exception:
        #Best effort: any failure to probe means "least suitable".
        #Exception (rather than a bare except) lets KeyboardInterrupt
        #and SystemExit propagate.
        return 0
    #rough layout of the importance of various features
    simple_criterion = [("libvpx", 20), ("264", 10), ("265", 3),
                        ("svtav1", 5), ("libopus", 1)]
    score = sum(weight for feature, weight in simple_criterion
                if feature in version)
    #obtain rough compile year from copyright information
    copyright_index = version.find('2000-2')
    if copyright_index >= 0:
        #The three characters after the matched '2000-2', e.g. '024'
        copyright_year = version[copyright_index+6:copyright_index+9]
        #isdecimal guarantees that int() will accept the text
        if copyright_year.isdecimal():
            score += int(copyright_year)
    return score
#Select the ffmpeg executable used by this module and store it in ffmpeg_path.
#An explicit VHS_FORCE_FFMPEG_PATH environment variable wins outright.
if "VHS_FORCE_FFMPEG_PATH" in os.environ:
    ffmpeg_path = os.environ.get("VHS_FORCE_FFMPEG_PATH")
else:
    ffmpeg_paths = []
    try:
        #imageio_ffmpeg is optional; a failure here is ignored unless
        #VHS_USE_IMAGEIO_FFMPEG is set, in which case it is re-raised
        from imageio_ffmpeg import get_ffmpeg_exe
        imageio_ffmpeg_path = get_ffmpeg_exe()
        ffmpeg_paths.append(imageio_ffmpeg_path)
    except:
        if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
            raise
    if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
        #Only reached when the lookup above succeeded (otherwise it raised)
        ffmpeg_path = imageio_ffmpeg_path
    else:
        #Also consider ffmpeg found on PATH and any ffmpeg/ffmpeg.exe
        #located in the current working directory
        system_ffmpeg = shutil.which("ffmpeg")
        if system_ffmpeg is not None:
            ffmpeg_paths.append(system_ffmpeg)
        if os.path.isfile("ffmpeg"):
            ffmpeg_paths.append(os.path.abspath("ffmpeg"))
        if os.path.isfile("ffmpeg.exe"):
            ffmpeg_paths.append(os.path.abspath("ffmpeg.exe"))
        if len(ffmpeg_paths) == 0:
            #No candidate found; ffmpeg_path is left as None
            ffmpeg_path = None
        elif len(ffmpeg_paths) == 1:
            #Evaluation of suitability isn't required, can take sole option
            #to reduce startup time
            ffmpeg_path = ffmpeg_paths[0]
        else:
            #Several candidates: keep the one with the highest score
            ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability)
#Locate gifski: VHS_GIFSKI takes precedence, then JOV_GIFSKI, and finally
#whatever executable named "gifski" is found on PATH (None if there is none).
gifski_path = os.environ.get("VHS_GIFSKI", os.environ.get("JOV_GIFSKI"))
if gifski_path is None:
    gifski_path = shutil.which("gifski")
def is_safe_path(path):
    """Report whether path is allowed under the strict-paths policy.

    When the VHS_STRICT_PATHS environment variable is not set every path is
    accepted. When it is set, only paths located inside the current working
    directory are accepted.

    Args:
        path: The path to check.

    Returns:
        True if the path is acceptable, False otherwise.
    """
    if "VHS_STRICT_PATHS" not in os.environ:
        return True
    basedir = os.path.abspath('.')
    try:
        #commonpath does not collapse '..' components, so without
        #normalising first '<basedir>/../x' would be reported as being
        #inside basedir.
        common_path = os.path.commonpath([basedir, os.path.normpath(path)])
    except (TypeError, ValueError):
        #ValueError: different drive on windows, or a relative path mixed
        #with the absolute basedir. TypeError: path is not a usable string.
        return False
    return common_path == basedir
def get_sorted_dir_files_from_directory(directory: str, skip_first_images: int=0, select_every_nth: int=1, extensions: Iterable=None):
    """List the regular files of a directory in sorted name order.

    Args:
        directory: Directory to list; it is passed through strip_path first.
        skip_first_images: Number of leading entries to drop after filtering.
        select_every_nth: Step used when selecting from the remaining entries.
        extensions: Optional iterable of extensions including the leading
            dot (e.g. ".png"). The extension of each file is lower-cased
            before it is compared against these values.

    Returns:
        A list of paths, each joined with the directory.
    """
    directory = strip_path(directory)
    candidates = (os.path.join(directory, name)
                  for name in sorted(os.listdir(directory)))
    dir_files = [filepath for filepath in candidates
                 if os.path.isfile(filepath)]
    # filter by extension, if needed
    if extensions is not None:
        allowed = list(extensions)
        dir_files = [filepath for filepath in dir_files
                     if ("." + filepath.split(".")[-1]).lower() in allowed]
    # drop the first skip_first_images entries, then keep every nth one
    return dir_files[skip_first_images:][0::select_every_nth]
# modified from https://stackoverflow.com/questions/22058048/hashing-a-file-in-python
def calculate_file_hash(filename: str, hash_every_n: int = 1):
    """Return a sha256 hex digest identifying a file by name and mtime.

    The file content is not read: larger video files were taking >.5 seconds
    to hash even when cached, so the filesystem modification time is hashed
    together with the file name instead.

    Args:
        filename: Path of the file.
        hash_every_n: Not used by this implementation.

    Returns:
        The hexadecimal digest as a string.
    """
    digest = hashlib.sha256(filename.encode())
    modified = os.path.getmtime(filename)
    digest.update(str(modified).encode())
    return digest.hexdigest()
#Prompt queue of the server instance; the requeue helpers in this module
#read its currently_running entries and put new items on it.
prompt_queue = server.PromptServer.instance.prompt_queue
def requeue_workflow_unchecked():
    """Requeues the current workflow without checking for multiple requeues

    The first entry of prompt_queue.currently_running is put on the queue
    again under a freshly generated prompt id.
    """
    running_entry = next(iter(prompt_queue.currently_running.values()))
    (_, _, prompt, extra_data, outputs_to_execute) = running_entry

    #Ensure batch_managers are marked stale
    #NOTE: the copy is shallow, so the per-node dicts updated below are
    #shared with the prompt that is currently running
    prompt = prompt.copy()
    for node in prompt.values():
        if node['class_type'] == 'VHS_BatchManager':
            node_inputs = node['inputs']
            node_inputs['requeue'] = node_inputs.get('requeue',0)+1

    #execution.py has guards for concurrency, but server doesn't.
    #TODO: Check that this won't be an issue
    number = -server.PromptServer.instance.number
    server.PromptServer.instance.number += 1
    prompt_id = str(server.uuid.uuid4())
    queue_item = (number, prompt_id, prompt, extra_data, outputs_to_execute)
    prompt_queue.put(queue_item)
#State shared between requeue_workflow calls:
#[run number, calls seen for that run, managed output count, {key: flag}]
requeue_guard = [None, 0, 0, {}]
def requeue_workflow(requeue_required=(-1,True)):
    """Requeue the running workflow once every managed output has reported.

    Args:
        requeue_required: A (key, flag) pair. The flag is stored under the
            key in the guard for the current run.

    The workflow is requeued when the number of calls seen for the current
    run equals the number of managed outputs and the largest stored flag
    is truthy.
    """
    assert len(prompt_queue.currently_running) == 1
    global requeue_guard
    running_entry = next(iter(prompt_queue.currently_running.values()))
    (run_number, _, prompt, _, _) = running_entry
    if requeue_guard[0] != run_number:
        #Calculate a count of how many outputs are managed by a batch manager
        managed_outputs = sum(
            1
            for bm_uid, bm_node in prompt.items()
            if bm_node['class_type'] == 'VHS_BatchManager'
            for output_node in prompt.values()
            if output_node['class_type'] in ["VHS_VideoCombine"]
            for inp in output_node['inputs'].values()
            if inp == [bm_uid, 0]
        )
        #A new run was detected, so start from a fresh guard
        requeue_guard = [run_number, 0, managed_outputs, {}]
    requeue_guard[1] += 1
    requeue_guard[3][requeue_required[0]] = requeue_required[1]
    all_reported = requeue_guard[1] == requeue_guard[2]
    if all_reported and max(requeue_guard[3].values()):
        requeue_workflow_unchecked()
def get_audio(file, start_time=0, duration=0):
    """Decode the audio of a file with ffmpeg into a float32 tensor.

    Args:
        file: Input passed to ffmpeg with ``-i``.
        start_time: When greater than 0, passed to ffmpeg as ``-ss``.
        duration: When greater than 0, passed to ffmpeg as ``-t``.

    Returns:
        A dict with the keys 'waveform' and 'sample_rate'. On success the
        waveform has the shape (1, channels, samples) and sample_rate is an
        int. If ffmpeg exits with an error, a zero tensor of shape (1, 2)
        and a sample rate of 16000 are returned instead.

    Raises:
        KeyError: If ffmpeg reports a channel layout other than mono or
            stereo.
    """
    args = [ffmpeg_path, "-i", file]
    if start_time > 0:
        args += ["-ss", str(start_time)]
    if duration > 0:
        args += ["-t", str(duration)]
    try:
        #TODO: scan for sample rate and maintain
        res = subprocess.run(args + ["-f", "f32le", "-"],
                             capture_output=True, check=True)
        audio = torch.frombuffer(bytearray(res.stdout), dtype=torch.float32)
    except subprocess.CalledProcessError:
        #sample_rate is an int here as well, matching the success path
        #NOTE(review): this placeholder is 2-D while the success path
        #returns a 3-D tensor - confirm what callers expect
        return {'waveform': torch.zeros(1,2), 'sample_rate': 16000}
    #Sample rate and channel layout are parsed from ffmpeg's stderr text
    match = re.search(', (\\d+) Hz, (\\w+), ',res.stderr.decode('utf-8'))
    if match:
        ar = int(match.group(1))
        #NOTE: Just throwing an error for other channel types right now
        #Will deal with issues if they come
        ac = {"mono": 1, "stereo": 2}[match.group(2)]
    else:
        ar = 44100
        ac = 2
    #Rows of ac values -> (channels, samples) -> leading dimension of 1
    audio = audio.reshape((-1,ac)).transpose(0,1).unsqueeze(0)
    return {'waveform': audio, 'sample_rate': ar}
class LazyAudioMap(Mapping):
    """Read-only mapping whose content is produced by get_audio on first use.

    Nothing is decoded on construction; the first item lookup, iteration or
    len() call invokes get_audio and keeps its result for later accesses.
    """

    def __init__(self, file, start_time, duration):
        self.file = file
        self.start_time = start_time
        self.duration = duration
        #Result of get_audio, or None while nothing has been loaded
        self._dict = None

    def _loaded(self):
        #Return the audio dict, calling get_audio only the first time
        if self._dict is None:
            self._dict = get_audio(self.file, self.start_time, self.duration)
        return self._dict

    def __getitem__(self, key):
        return self._loaded()[key]

    def __iter__(self):
        return iter(self._loaded())

    def __len__(self):
        return len(self._loaded())
def lazy_eval(file, start_time=0, duration=0):
    """Return a LazyAudioMap for the given file, start time and duration."""
    return LazyAudioMap(file, start_time=start_time, duration=duration)
def is_url(url):
    """Return True if url starts with the ``http://`` or ``https://`` scheme.

    The scheme separator is required, so a plain path that is merely named
    "http" or "https" is not treated as a URL.

    Args:
        url: The string to classify.

    Returns:
        True for http/https URLs, False otherwise.
    """
    return url.startswith(("http://", "https://"))
def validate_sequence(path):
    """Check if path is a valid ffmpeg sequence that points to at least one file

    The file name part of path must contain a ``%d`` or ``%0Nd`` style
    placeholder. ``%d`` (or a width of zero) stands for one or more digits;
    any other width N stands for exactly N digits. Every other character of
    the file name is matched literally.

    Args:
        path: Sequence pattern, e.g. ``frames/img_%04d.png``.

    Returns:
        True if the directory part exists and at least one entry in it
        matches the pattern, False otherwise.
    """
    (directory, pattern) = os.path.split(path)
    if not os.path.isdir(directory):
        return False
    #\d* rather than \d+ so that the plain '%d' form is recognised too
    placeholder = re.compile(r'%0?\d*d')
    if not placeholder.search(pattern):
        return False
    pieces = []
    position = 0
    for found in placeholder.finditer(pattern):
        #Literal text is escaped so '.', '(' or '[' only match themselves
        pieces.append(re.escape(pattern[position:found.start()]))
        width = found.group()[1:-1].lstrip('0')
        pieces.append(r'\d{%s}' % width if width else r'\d+')
        position = found.end()
    pieces.append(re.escape(pattern[position:]))
    file_matcher = re.compile(''.join(pieces))
    return any(file_matcher.fullmatch(name) for name in os.listdir(directory))
def strip_path(path):
    """Trim surrounding whitespace and at most one double quote per side.

    Whitespace inside the quotes is kept and only a single quote character
    is removed from each end, thus ' ""test"' -> '"test'.
    """
    #Possible alternatives: stripping whitespace and quote characters in one
    #strip call, or a weightier regular expression
    trimmed = path.strip()
    leading = 1 if trimmed.startswith("\"") else 0
    trimmed = trimmed[leading:]
    if trimmed.endswith("\""):
        return trimmed[:-1]
    return trimmed
def hash_path(path):
    """Return a hash value for a path input.

    Returns "input" when path is None, "url" when is_url(path) is true, and
    otherwise the result of calculate_file_hash on the stripped path.
    """
    if path is None:
        return "input"
    return "url" if is_url(path) else calculate_file_hash(strip_path(path))
def validate_path(path, allow_none=False, allow_url=True):
    """Validate a user supplied path or URL.

    Args:
        path: The path or URL to validate, or None.
        allow_none: Value returned when path is None.
        allow_url: Whether URLs are acceptable.

    Returns:
        An error message string when the path is rejected for a known
        reason; otherwise a bool (allow_none for a None path, or the result
        of is_safe_path).
    """
    if path is None:
        return allow_none
    if is_url(path):
        #Probably not feasible to check if url resolves here
        if not allow_url:
            return "URLs are unsupported for this path"
        return is_safe_path(path)
    #Strip once and use the same cleaned path for both checks; previously
    #the safety check received the raw, possibly quoted, string
    stripped = strip_path(path)
    if not os.path.isfile(stripped):
        return "Invalid file path: {}".format(path)
    return is_safe_path(stripped)