Do a better RandomizedSet

This commit is contained in:
Dan Buch 2023-10-26 03:58:50 -04:00
parent c8a4928ee8
commit c274f59f58
Signed by: meatballhat
GPG Key ID: A12F782281063434
2 changed files with 69 additions and 6 deletions

View File

@ -396,7 +396,7 @@ def h_index(citations: list[int]) -> int:
return last_qualified or 0 return last_qualified or 0
class RandomizedSet: class SlowRandomizedSet:
def __init__(self): def __init__(self):
self._i: set[int] = set() self._i: set[int] = set()
@ -412,4 +412,36 @@ class RandomizedSet:
return False return False
def getRandom(self) -> int: def getRandom(self) -> int:
return next(iter(sorted(list(self._i), key=lambda _: random.random()))) return random.choice(list(self._i))
class RandomizedSet:
def __init__(self):
self._l: list[int] = []
self._m: dict[int, int] = {}
def insert(self, val: int) -> bool:
if val in self._m:
return False
self._m[val] = len(self._l)
self._l.append(val)
return True
def remove(self, val: int) -> bool:
if val not in self._m:
return False
val_loc = self._m[val]
last_val = self._l[-1]
self._l[val_loc] = last_val
self._m[last_val] = val_loc
self._l.pop()
self._m.pop(val)
return True
def getRandom(self) -> int:
return random.choice(self._l)

View File

@ -333,8 +333,15 @@ def test_h_index(citations: list[int], expected: int):
assert stuff.h_index(citations) == expected assert stuff.h_index(citations) == expected
def test_randomized_set(): @pytest.mark.parametrize(
inst = stuff.RandomizedSet() ("cls",),
[
(stuff.SlowRandomizedSet,),
(stuff.RandomizedSet,),
],
)
def test_randomized_set(cls: type[stuff.RandomizedSet] | type[stuff.SlowRandomizedSet]):
inst = cls()
assert inst.insert(1) is True assert inst.insert(1) is True
assert inst.remove(2) is False assert inst.remove(2) is False
@ -344,7 +351,7 @@ def test_randomized_set():
assert inst.insert(2) is False assert inst.insert(2) is False
assert inst.getRandom() == 2 assert inst.getRandom() == 2
inst = stuff.RandomizedSet() inst = cls()
assert inst.insert(1) is True assert inst.insert(1) is True
assert inst.insert(10) is True assert inst.insert(10) is True
@ -353,7 +360,31 @@ def test_randomized_set():
seen: set[int] = set() seen: set[int] = set()
for _ in range(100_000): for _ in range(10_000):
seen.add(inst.getRandom()) seen.add(inst.getRandom())
assert seen == {1, 10, 20, 30} assert seen == {1, 10, 20, 30}
# ["remove","remove","insert","getRandom","remove","insert"]
# [[0],[0],[0],[],[0],[0]]
inst = cls()
assert inst.remove(0) is False
assert inst.remove(0) is False
assert inst.insert(0) is True
assert inst.getRandom() == 0
assert inst.remove(0) is True
assert inst.insert(0) is True
# ["RandomizedSet","insert","insert","remove","insert","remove","getRandom"]
# [[],[0],[1],[0],[2],[1],[]]
inst = cls()
assert inst.insert(0) is True
assert inst.insert(1) is True
assert inst.remove(0) is True
assert inst.insert(2) is True
assert inst.remove(1) is True
assert inst.getRandom() == 2