forked from ktrk115/const_layout
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rico.py
116 lines (93 loc) · 3.29 KB
/
rico.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
from pathlib import Path
import torch
from torch_geometric.data import Data
from data.base import BaseDataset
def append_child(element, elements):
    """Collect every descendant of ``element`` into ``elements``.

    Descendants are reached through the ``'children'`` key and are appended
    in depth-first pre-order (a node is appended before its own children,
    siblings keep their listed order). ``element`` itself is not appended.

    Args:
        element: mapping that may hold a ``'children'`` sequence of mappings.
        elements: list that receives the descendants; it is mutated in place.

    Returns:
        The same ``elements`` list that was passed in.
    """
    # Explicit stack instead of recursion. Children are pushed reversed so
    # that popping yields them in their original left-to-right order.
    pending = list(reversed(element['children'])) if 'children' in element else []
    while pending:
        node = pending.pop()
        elements.append(node)
        if 'children' in node:
            pending.extend(reversed(node['children']))
    return elements
class Rico(BaseDataset):
    """Layout dataset built from the ``semantic_annotations`` JSON files.

    Each kept annotation file becomes one ``Data`` object whose ``x`` holds
    normalized ``[xc, yc, w, h]`` boxes and whose ``y`` holds label indices.

    NOTE(review): ``raw_dir``, ``processed_paths``, ``collate`` and
    ``label2index`` are not defined in this class; presumably they are
    provided by ``BaseDataset`` -- confirm there.
    """

    # Component labels kept by this dataset; elements carrying any other
    # ``componentLabel`` are dropped in ``process``.
    labels = [
        'Toolbar',
        'Image',
        'Text',
        'Icon',
        'Text Button',
        'Input',
        'List Item',
        'Advertisement',
        'Pager Indicator',
        'Web View',
        'Background Image',
        'Drawer',
        'Modal',
    ]

    def __init__(self, split='train', transform=None):
        """Forward to the base class with the dataset name ``'rico'``."""
        super().__init__('rico', split, transform)

    def download(self):
        """Delegate to the base-class ``download``."""
        super().download()

    @staticmethod
    def _is_valid(element, allowed_labels, W, H):
        """Return True if ``element`` has a kept label and a usable box.

        A box is usable when it lies inside ``[0, W] x [0, H]`` and has
        strictly positive width and height.
        """
        if element['componentLabel'] not in allowed_labels:
            return False
        x1, y1, x2, y2 = element['bounds']
        if x1 < 0 or y1 < 0 or W < x2 or H < y2:
            return False
        if x2 <= x1 or y2 <= y1:
            return False
        return True

    @staticmethod
    def _normalized_box(bounds, W, H):
        """Convert ``[x1, y1, x2, y2]`` to ``[xc/W, yc/H, w/W, h/H]``."""
        x1, y1, x2, y2 = bounds
        xc = (x1 + x2) / 2.
        yc = (y1 + y2) / 2.
        width = x2 - x1
        height = y2 - y1
        return [xc / W, yc / H, width / W, height / H]

    def process(self):
        """Parse the raw annotations and save three shuffled splits.

        Files are skipped when the root bounds do not start at (0, 0), when
        the canvas is wider than tall, or when the number of valid elements
        is 0 or greater than 9. The remaining samples are shuffled with a
        fixed seed and saved as the first 85%, the next 5% and the last 10%
        to ``self.processed_paths[0]``, ``[1]`` and ``[2]`` respectively.
        """
        # Built once: the original rebuilt this set for every element.
        allowed_labels = frozenset(self.labels)
        max_elements = 9  # files with more valid elements are skipped

        data_list = []
        raw_dir = Path(self.raw_dir) / 'semantic_annotations'
        for json_path in sorted(raw_dir.glob('*.json')):
            # JSON is UTF-8 by specification; do not rely on the locale.
            with json_path.open(encoding='utf-8') as f:
                ann = json.load(f)

            # Root bounds are [x1, y1, x2, y2]; with the origin at (0, 0)
            # the last two values are the canvas width and height.
            B = ann['bounds']
            W, H = float(B[2]), float(B[3])
            if B[0] != 0 or B[1] != 0 or H < W:
                continue

            candidates = append_child(ann, [])
            elements = [
                element for element in candidates
                if self._is_valid(element, allowed_labels, W, H)
            ]
            # True when at least one descendant was dropped by validation.
            filtered = len(candidates) != len(elements)

            N = len(elements)
            if N == 0 or max_elements < N:
                continue

            boxes = [
                self._normalized_box(element['bounds'], W, H)
                for element in elements
            ]
            label_ids = [
                self.label2index[element['componentLabel']]
                for element in elements
            ]

            data = Data(
                x=torch.tensor(boxes, dtype=torch.float),
                y=torch.tensor(label_ids, dtype=torch.long),
            )
            data.attr = {
                'name': json_path.name,
                'width': W,
                'height': H,
                'filtered': filtered,
                'has_canvas_element': False,
            }
            data_list.append(data)

        # shuffle with seed so the split is reproducible across runs
        generator = torch.Generator().manual_seed(0)
        indices = torch.randperm(len(data_list), generator=generator)
        data_list = [data_list[i] for i in indices.tolist()]

        # train 85% / val 5% / test 10%
        N = len(data_list)
        s = [int(N * .85), int(N * .90)]
        torch.save(self.collate(data_list[:s[0]]), self.processed_paths[0])
        torch.save(self.collate(data_list[s[0]:s[1]]), self.processed_paths[1])
        torch.save(self.collate(data_list[s[1]:]), self.processed_paths[2])