diff --git a/leetcode/stuff.py b/leetcode/stuff.py index d98704a..c1633cc 100644 --- a/leetcode/stuff.py +++ b/leetcode/stuff.py @@ -396,7 +396,7 @@ def h_index(citations: list[int]) -> int: return last_qualified or 0 -class RandomizedSet: +class SlowRandomizedSet: def __init__(self): self._i: set[int] = set() @@ -412,4 +412,36 @@ class RandomizedSet: return False def getRandom(self) -> int: - return next(iter(sorted(list(self._i), key=lambda _: random.random()))) + return random.choice(list(self._i)) + + +class RandomizedSet: + def __init__(self): + self._l: list[int] = [] + self._m: dict[int, int] = {} + + def insert(self, val: int) -> bool: + if val in self._m: + return False + + self._m[val] = len(self._l) + self._l.append(val) + + return True + + def remove(self, val: int) -> bool: + if val not in self._m: + return False + + val_loc = self._m[val] + last_val = self._l[-1] + self._l[val_loc] = last_val + self._m[last_val] = val_loc + + self._l.pop() + self._m.pop(val) + + return True + + def getRandom(self) -> int: + return random.choice(self._l) diff --git a/leetcode/test_stuff.py b/leetcode/test_stuff.py index 3730f43..11ef6ff 100644 --- a/leetcode/test_stuff.py +++ b/leetcode/test_stuff.py @@ -333,8 +333,15 @@ def test_h_index(citations: list[int], expected: int): assert stuff.h_index(citations) == expected -def test_randomized_set(): - inst = stuff.RandomizedSet() +@pytest.mark.parametrize( + ("cls",), + [ + (stuff.SlowRandomizedSet,), + (stuff.RandomizedSet,), + ], +) +def test_randomized_set(cls: type[stuff.RandomizedSet] | type[stuff.SlowRandomizedSet]): + inst = cls() assert inst.insert(1) is True assert inst.remove(2) is False @@ -344,7 +351,7 @@ def test_randomized_set(): assert inst.insert(2) is False assert inst.getRandom() == 2 - inst = stuff.RandomizedSet() + inst = cls() assert inst.insert(1) is True assert inst.insert(10) is True @@ -353,7 +360,31 @@ def test_randomized_set(): seen: set[int] = set() - for _ in range(100_000): + for _ in range(10_000): seen.add(inst.getRandom()) assert seen == {1, 10, 20, 30} + + # ["remove","remove","insert","getRandom","remove","insert"] + # [[0],[0],[0],[],[0],[0]] + + inst = cls() + + assert inst.remove(0) is False + assert inst.remove(0) is False + assert inst.insert(0) is True + assert inst.getRandom() == 0 + assert inst.remove(0) is True + assert inst.insert(0) is True + + # ["RandomizedSet","insert","insert","remove","insert","remove","getRandom"] + # [[],[0],[1],[0],[2],[1],[]] + + inst = cls() + + assert inst.insert(0) is True + assert inst.insert(1) is True + assert inst.remove(0) is True + assert inst.insert(2) is True + assert inst.remove(1) is True + assert inst.getRandom() == 2