Skip to content

Commit 425cbbd

Browse files
Changed sortedlist.py to a shorter, faster and more user friendly version
1 parent ce94312 commit 425cbbd

File tree

1 file changed

+110
-221
lines changed

1 file changed

+110
-221
lines changed

pyrival/data_structures/SortedList.py

Lines changed: 110 additions & 221 deletions
Original file line numberDiff line numberDiff line change
@@ -1,233 +1,122 @@
1-
class SortedList:
2-
def __init__(self, iterable=[], _load=200):
3-
"""Initialize sorted list instance."""
4-
values = sorted(iterable)
5-
self._len = _len = len(values)
6-
self._load = _load
7-
self._lists = _lists = [values[i:i + _load] for i in range(0, _len, _load)]
8-
self._list_lens = [len(_list) for _list in _lists]
9-
self._mins = [_list[0] for _list in _lists]
10-
self._fen_tree = []
11-
self._rebuild = True
12-
13-
def _fen_build(self):
14-
"""Build a fenwick tree instance."""
15-
self._fen_tree[:] = self._list_lens
16-
_fen_tree = self._fen_tree
17-
for i in range(len(_fen_tree)):
18-
if i | i + 1 < len(_fen_tree):
19-
_fen_tree[i | i + 1] += _fen_tree[i]
20-
self._rebuild = False
21-
22-
def _fen_update(self, index, value):
23-
"""Update `fen_tree[index] += value`."""
24-
if not self._rebuild:
25-
_fen_tree = self._fen_tree
26-
while index < len(_fen_tree):
27-
_fen_tree[index] += value
28-
index |= index + 1
29-
30-
def _fen_query(self, end):
31-
"""Return `sum(_fen_tree[:end])`."""
32-
if self._rebuild:
33-
self._fen_build()
34-
35-
_fen_tree = self._fen_tree
1+
"""
2+
The "sorted list" data-structure, with amortized O(n^(1/3)) cost per insert and pop.
3+
4+
Example:
5+
6+
A = SortedList()
7+
A.insert(30)
8+
A.insert(50)
9+
A.insert(20)
10+
A.insert(30)
11+
A.insert(30)
12+
13+
print(A) # prints [20, 30, 30, 30, 50]
14+
15+
print(A.lower_bound(30), A.upper_bound(30)) # prints 1 4
16+
17+
print(A[-1]) # prints 50
18+
print(A.pop(1)) # prints 30
19+
20+
print(A) # prints [20, 30, 30, 50]
21+
print(A.count(30)) # prints 2
22+
23+
"""
24+
25+
from bisect import bisect_left as lower_bound, bisect_right as upper_bound
26+
27+
class FenwickTree:
28+
def __init__(self, x):
29+
bit = self.bit = list(x)
30+
size = self.size = len(bit)
31+
for i in range(size):
32+
j = i | (i + 1)
33+
if j < size:
34+
bit[j] += bit[i]
35+
36+
def update(self, idx, x):
37+
"""updates bit[idx] += x"""
38+
while idx < self.size:
39+
self.bit[idx] += x
40+
idx |= idx + 1
41+
42+
def __call__(self, end):
43+
"""calc sum(bit[:end])"""
3644
x = 0
3745
while end:
38-
x += _fen_tree[end - 1]
46+
x += self.bit[end - 1]
3947
end &= end - 1
4048
return x
4149

42-
def _fen_findkth(self, k):
43-
"""Return a pair of (the largest `idx` such that `sum(_fen_tree[:idx]) <= k`, `k - sum(_fen_tree[:idx])`)."""
44-
_list_lens = self._list_lens
45-
if k < _list_lens[0]:
46-
return 0, k
47-
if k >= self._len - _list_lens[-1]:
48-
return len(_list_lens) - 1, k + _list_lens[-1] - self._len
49-
if self._rebuild:
50-
self._fen_build()
51-
52-
_fen_tree = self._fen_tree
50+
def find_kth(self, k):
51+
"""Find largest idx such that sum(bit[:idx]) <= k"""
5352
idx = -1
54-
for d in reversed(range(len(_fen_tree).bit_length())):
53+
for d in reversed(range(self.size.bit_length())):
5554
right_idx = idx + (1 << d)
56-
if right_idx < len(_fen_tree) and k >= _fen_tree[right_idx]:
55+
if right_idx < self.size and self.bit[right_idx] <= k:
5756
idx = right_idx
58-
k -= _fen_tree[idx]
57+
k -= self.bit[idx]
5958
return idx + 1, k
60-
61-
def _delete(self, pos, idx):
62-
"""Delete value at the given `(pos, idx)`."""
63-
_lists = self._lists
64-
_mins = self._mins
65-
_list_lens = self._list_lens
66-
67-
self._len -= 1
68-
self._fen_update(pos, -1)
69-
del _lists[pos][idx]
70-
_list_lens[pos] -= 1
71-
72-
if _list_lens[pos]:
73-
_mins[pos] = _lists[pos][0]
74-
else:
75-
del _lists[pos]
76-
del _list_lens[pos]
77-
del _mins[pos]
78-
self._rebuild = True
79-
80-
def _loc_left(self, value):
81-
"""Return an index pair that corresponds to the first position of `value` in the sorted list."""
82-
if not self._len:
83-
return 0, 0
84-
85-
_lists = self._lists
86-
_mins = self._mins
87-
88-
lo, pos = -1, len(_lists) - 1
89-
while lo + 1 < pos:
90-
mi = (lo + pos) >> 1
91-
if value <= _mins[mi]:
92-
pos = mi
93-
else:
94-
lo = mi
95-
96-
if pos and value <= _lists[pos - 1][-1]:
97-
pos -= 1
98-
99-
_list = _lists[pos]
100-
lo, idx = -1, len(_list)
101-
while lo + 1 < idx:
102-
mi = (lo + idx) >> 1
103-
if value <= _list[mi]:
104-
idx = mi
105-
else:
106-
lo = mi
107-
108-
return pos, idx
109-
110-
def _loc_right(self, value):
111-
"""Return an index pair that corresponds to the last position of `value` in the sorted list."""
112-
if not self._len:
113-
return 0, 0
114-
115-
_lists = self._lists
116-
_mins = self._mins
117-
118-
pos, hi = 0, len(_lists)
119-
while pos + 1 < hi:
120-
mi = (pos + hi) >> 1
121-
if value < _mins[mi]:
122-
hi = mi
123-
else:
124-
pos = mi
125-
126-
_list = _lists[pos]
127-
lo, idx = -1, len(_list)
128-
while lo + 1 < idx:
129-
mi = (lo + idx) >> 1
130-
if value < _list[mi]:
131-
idx = mi
132-
else:
133-
lo = mi
134-
135-
return pos, idx
136-
137-
def add(self, value):
138-
"""Add `value` to sorted list."""
139-
_load = self._load
140-
_lists = self._lists
141-
_mins = self._mins
142-
_list_lens = self._list_lens
143-
144-
self._len += 1
145-
if _lists:
146-
pos, idx = self._loc_right(value)
147-
self._fen_update(pos, 1)
148-
_list = _lists[pos]
149-
_list.insert(idx, value)
150-
_list_lens[pos] += 1
151-
_mins[pos] = _list[0]
152-
if _load + _load < len(_list):
153-
_lists.insert(pos + 1, _list[_load:])
154-
_list_lens.insert(pos + 1, len(_list) - _load)
155-
_mins.insert(pos + 1, _list[_load])
156-
_list_lens[pos] = _load
157-
del _list[_load:]
158-
self._rebuild = True
159-
else:
160-
_lists.append([value])
161-
_mins.append(value)
162-
_list_lens.append(1)
163-
self._rebuild = True
164-
165-
def discard(self, value):
166-
"""Remove `value` from sorted list if it is a member."""
167-
_lists = self._lists
168-
if _lists:
169-
pos, idx = self._loc_right(value)
170-
if idx and _lists[pos][idx - 1] == value:
171-
self._delete(pos, idx - 1)
172-
173-
def remove(self, value):
174-
"""Remove `value` from sorted list; `value` must be a member."""
175-
_len = self._len
176-
self.discard(value)
177-
if _len == self._len:
178-
raise ValueError('{0!r} not in list'.format(value))
179-
180-
def pop(self, index=-1):
181-
"""Remove and return value at `index` in sorted list."""
182-
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
183-
value = self._lists[pos][idx]
184-
self._delete(pos, idx)
185-
return value
186-
187-
def bisect_left(self, value):
188-
"""Return the first index to insert `value` in the sorted list."""
189-
pos, idx = self._loc_left(value)
190-
return self._fen_query(pos) + idx
191-
192-
def bisect_right(self, value):
193-
"""Return the last index to insert `value` in the sorted list."""
194-
pos, idx = self._loc_right(value)
195-
return self._fen_query(pos) + idx
196-
197-
def count(self, value):
198-
"""Return number of occurrences of `value` in the sorted list."""
199-
return self.bisect_right(value) - self.bisect_left(value)
200-
59+
60+
61+
block_size = 700
62+
class SortedList:
63+
def __init__(self):
64+
self.macro = []
65+
self.micros = [[]]
66+
self.micro_size = [0]
67+
self.fenwick = FenwickTree([0])
68+
self.size = 0
69+
70+
def insert(self, x):
71+
i = lower_bound(self.macro, x)
72+
j = upper_bound(self.micros[i], x)
73+
self.micros[i].insert(j, x)
74+
self.size += 1
75+
self.micro_size[i] += 1
76+
self.fenwick.update(i, 1)
77+
if len(self.micros[i]) >= block_size:
78+
self.micros[i : i + 1] = self.micros[i][:block_size >> 1], self.micros[i][block_size >> 1:]
79+
self.micro_size[i : i + 1] = block_size >> 1, block_size >> 1
80+
self.fenwick = FenwickTree(self.micro_size)
81+
self.macro.insert(i, self.micros[i + 1][0])
82+
83+
def pop(self, k=0):
84+
i,j = self._find_kth(k)
85+
x = self.micros[i].pop(j)
86+
self.size -= 1
87+
self.micro_size[i] -= 1
88+
self.fenwick.update(i, -1)
89+
return x
90+
91+
def __getitem__(self, k):
92+
i,j = self._find_kth(k)
93+
return self.micros[i][j]
94+
95+
def __contains__(self, x):
96+
i = lower_bound(self.macro, x)
97+
j = lower_bound(self.micros[i], x)
98+
i,j = (i, j) if j < self.micro_size[i] else (i + 1, 0)
99+
return i < len(self.micros) and j < self.micros_size[i] and self.micros[i][j] == x
100+
101+
def lower_bound(self, x):
102+
i = lower_bound(self.macro, x)
103+
return self.fenwick(i) + lower_bound(self.micros[i], x)
104+
105+
def upper_bound(self, x):
106+
i = upper_bound(self.macro, x)
107+
return self.fenwick(i) + upper_bound(self.micros[i], x)
108+
109+
def _find_kth(self, k):
110+
return self.fenwick.find_kth(k + self.size if k < 0 else k)
111+
201112
def __len__(self):
202-
"""Return the size of the sorted list."""
203-
return self._len
204-
205-
def __getitem__(self, index):
206-
"""Lookup value at `index` in sorted list."""
207-
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
208-
return self._lists[pos][idx]
209-
210-
def __delitem__(self, index):
211-
"""Remove value at `index` from sorted list."""
212-
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
213-
self._delete(pos, idx)
214-
215-
def __contains__(self, value):
216-
"""Return true if `value` is an element of the sorted list."""
217-
_lists = self._lists
218-
if _lists:
219-
pos, idx = self._loc_left(value)
220-
return idx < len(_lists[pos]) and _lists[pos][idx] == value
221-
return False
222-
113+
return self.size
114+
223115
def __iter__(self):
224-
"""Return an iterator over the sorted list."""
225-
return (value for _list in self._lists for value in _list)
226-
227-
def __reversed__(self):
228-
"""Return a reverse iterator over the sorted list."""
229-
return (value for _list in reversed(self._lists) for value in reversed(_list))
230-
116+
return (x for micro in self.micros for x in micro)
117+
231118
def __repr__(self):
232-
"""Return string representation of sorted list."""
233-
return 'SortedList({0})'.format(list(self))
119+
return str(list(self))
120+
121+
def count(self, x):
122+
return self.upper_bound(x) - self.lower_bound(x)

0 commit comments

Comments
 (0)