-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
70 lines (53 loc) · 2.06 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import spaces
import gradio as gr
import cv2
import numpy as np
import time
import random
from PIL import Image
import torch
# Replace torch.jit.script with an identity function, so any
# `@torch.jit.script` applied after this line returns the original Python
# function uncompiled. It is placed before the transparent_background import
# below, presumably so the patch is already active while that package loads.
# NOTE(review): likely a workaround for TorchScript compilation failing inside
# transparent_background or its dependencies; confirm before removing.
torch.jit.script = lambda f: f
from transparent_background import Remover
@spaces.GPU()
def doo(video, mode, progress=gr.Progress()):
    """Remove the background from every frame of *video* and save it as an MP4.

    Every frame is passed through ``transparent_background.Remover.process``
    with ``type='green'``. The result is written with OpenCV's ``'mp4v'``
    codec to ``<random 9-digit number>.mp4`` in the current working directory.

    Args:
        video: Path of the input video (the value passed by Gradio's
            ``"video"`` input component).
        mode: Label selected in the mode Radio. ``'Quick'`` (the label the UI
            offers) or ``'Fast'`` selects ``Remover(mode='fast')``. Any other
            value uses the default ``Remover()``.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        The output file name. If elapsed processing time reaches the time
        budget, processing stops early and the partially written file is
        returned.

    Raises:
        gr.Error: If no frame could be read, so no output file was written.
    """
    # Accept the label the UI actually sends ('Quick') as well as the older
    # 'Fast' spelling, so both select the fast remover.
    if mode in ('Quick', 'Fast'):
        remover = Remover(mode='fast')
    else:
        remover = Remover()

    output_path = str(random.randint(111111111, 999999999)) + '.mp4'
    # Stop 5 s short of 20 minutes of processing.
    # NOTE(review): @spaces.GPU() is called without a `duration`; confirm the
    # GPU allocation really lasts this long.
    time_budget_s = 20 * 60 - 5

    cap = cv2.VideoCapture(video)
    writer = None
    try:
        # The frame count comes from container metadata and may be 0 or
        # inaccurate, so it is only used for progress reporting when positive.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Missing or invalid FPS metadata (0 or NaN): use an assumed
            # 30 fps so the writer is given a positive frame rate.
            fps = 30.0
        processed_frames = 0
        start_time = time.time()
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if time.time() - start_time >= time_budget_s:
                print("GPU Timing Out")
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(frame).convert('RGB')
            if writer is None:
                # Size the output from the first frame; PIL's size is
                # (width, height), the order VideoWriter expects.
                writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, img.size)
            processed_frames += 1
            print(f"Processing frame {processed_frames}")
            if total_frames > 0:
                # Clamp, because the metadata frame count can undercount.
                progress(min(processed_frames / total_frames, 1.0), desc=f"Processing frame {processed_frames}/{total_frames}")
            out = remover.process(img, type='green')
            # The remover output is presumably an RGB PIL image; OpenCV writes
            # BGR. (COLOR_RGB2BGR is the same channel swap as COLOR_BGR2RGB.)
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR))
    finally:
        # Release both handles even if decoding or background removal raises.
        cap.release()
        if writer is not None:
            writer.release()

    if writer is None:
        raise gr.Error("No frames could be read from the input video.")
    return output_path
description="Bigger the file size, Longer the time takes. May got GPU timeout ( Abort / Error )"
# Each example row supplies one value per input component: (video, mode).
examples = [['./input2.mp4', 'Standard'], ['./input.mp4', 'Standard']]
iface = gr.Interface(
    fn=doo,
    inputs=[
        "video",
        gr.components.Radio(
            ['Standard', 'Quick'],
            label='Select mode',
            # The default must be one of the choices above ('Normal' was not).
            value='Standard',
            info='Standard is more accurate but takes longer⏪, while quick is quicker but less accurate.⏩',
        ),
    ],
    outputs="video",
    examples=examples,
    description=description,
)

# Launch only when run as a script (e.g. `python app.py`), not when imported.
if __name__ == "__main__":
    iface.launch()