diff --git a/leetcode/stdlib.py b/leetcode/stdlib.py index cafdb8d..e307202 100644 --- a/leetcode/stdlib.py +++ b/leetcode/stdlib.py @@ -1,6 +1,11 @@ import typing +class LinkedListNode(typing.Protocol): + val: int + next: typing.Optional["LinkedListNode"] + + class ListNode: """ListNode is the leetcode "standard library" type used in linked lists""" @@ -9,8 +14,67 @@ class ListNode: self.next = next +class BinaryTreeNode(typing.Protocol): + val: int + left: typing.Optional["BinaryTreeNode"] + right: typing.Optional["BinaryTreeNode"] + + +class TreeNode: + """TreeNode is the leetcode "standard library" type used in binary trees""" + + def __init__( + self, + val: int = 0, + left: typing.Optional["TreeNode"] = None, + right: typing.Optional["TreeNode"] = None, + ): + self.val = val + self.left = left + self.right = right + + @classmethod + def from_int(cls, val: int | None) -> typing.Optional["TreeNode"]: + if val is None: + return None + + return TreeNode(val) + + # __repr__ was added by me + def __repr__(self) -> str: + filtered_parts = [] + + for key, value in [ + ("val", self.val), + ("right", self.right), + ("left", self.left), + ]: + if value is not None: + filtered_parts.append((key, value)) + + middle = ", ".join([f"{k}={v!r}" for k, v in filtered_parts]) + + return f"TreeNode({middle})" + + # __eq__ was added by me + def __eq__(self, other: typing.Optional["TreeNode"]) -> bool: + return ( + other is not None + and self.val == other.val + and self.left == other.left + and self.right == other.right + ) + + +class ConnectableBinaryTreeNode(typing.Protocol): + val: int + left: typing.Optional["BinaryTreeNode"] + right: typing.Optional["BinaryTreeNode"] + next: typing.Optional["BinaryTreeNode"] + + class Node: - """Node is the leetcode "standard library" type used in binary trees""" + """Node is the *other* leetcode "standard library" type used in binary trees""" def __init__( self, diff --git a/leetcode/stuff.py b/leetcode/stuff.py index a6b5cd7..ff5b1f9 100644 --- a/leetcode/stuff.py +++ b/leetcode/stuff.py @@ -163,7 +163,7 @@ class MinStack: return self._min[-1] -def linked_list_to_list(head: stdlib.ListNode | None) -> list[int]: +def linked_list_to_list(head: stdlib.LinkedListNode | None) -> list[int]: seen: set[int] = set() ret: list[int] = [] @@ -178,9 +178,9 @@ def linked_list_to_list(head: stdlib.ListNode | None) -> list[int]: return ret -def sort_linked_list(head: stdlib.ListNode | None) -> stdlib.ListNode | None: - by_val: list[tuple[int, stdlib.ListNode]] = [] - ret: stdlib.ListNode | None = None +def sort_linked_list(head: stdlib.LinkedListNode | None) -> stdlib.LinkedListNode | None: + by_val: list[tuple[int, stdlib.LinkedListNode]] = [] + ret: stdlib.LinkedListNode | None = None while head is not None: by_val.append((head.val, head)) @@ -203,12 +203,13 @@ def sort_linked_list(head: stdlib.ListNode | None) -> stdlib.ListNode | None: def connect_binary_tree_right( - root: stdlib.Node | None, -) -> tuple[stdlib.Node | None, list[int | None]]: + root: stdlib.ConnectableBinaryTreeNode | None, +) -> tuple[stdlib.ConnectableBinaryTreeNode | None, list[int | None]]: if root is None: return None, [] by_level = binary_tree_by_level(copy.deepcopy(root)) + by_level = typing.cast(dict[int, list[stdlib.ConnectableBinaryTreeNode]], by_level) serialized: list[int | None] = [] print("") @@ -231,8 +232,10 @@ def connect_binary_tree_right( return connected_root, serialized -def binary_tree_by_level(root: stdlib.Node) -> dict[int, list[stdlib.Node]]: - combined: dict[int, list[stdlib.Node]] = {} +def binary_tree_by_level( + root: stdlib.BinaryTreeNode, +) -> dict[int, list[stdlib.BinaryTreeNode]]: + combined: dict[int, list[stdlib.BinaryTreeNode]] = {} for path in collect_binary_tree_levels(0, root): level, node = path @@ -243,8 +246,8 @@ def binary_tree_by_level(root: stdlib.Node) -> dict[int, list[stdlib.Node]]: def collect_binary_tree_levels( - level: int, node: stdlib.Node | None -) -> typing.Iterator[tuple[int, stdlib.Node]]: + level: int, node: stdlib.BinaryTreeNode | None +) -> typing.Iterator[tuple[int, stdlib.BinaryTreeNode]]: if node is None: return @@ -253,7 +256,7 @@ def collect_binary_tree_levels( yield from collect_binary_tree_levels(level + 1, node.left) -def sum_binary_tree_path_ints(root: stdlib.Node | None) -> int: +def sum_binary_tree_path_ints(root: stdlib.BinaryTreeNode | None) -> int: path_ints: list[int] = [] for path in collect_binary_tree_paths(root): @@ -262,9 +265,20 @@ def sum_binary_tree_path_ints(root: stdlib.Node | None) -> int: return sum(path_ints) +def binary_tree_paths_as_lists( + paths: list[list[stdlib.BinaryTreeNode]], +) -> list[list[int]]: + paths_vals: list[list[int]] = [] + + for path in paths: + paths_vals.append([node.val for node in path]) + + return paths_vals + + def collect_binary_tree_paths( - node: stdlib.Node | None, -) -> typing.Iterator[list[stdlib.Node]]: + node: stdlib.BinaryTreeNode | None, +) -> typing.Iterator[list[stdlib.BinaryTreeNode]]: if node is None: return @@ -279,3 +293,29 @@ def collect_binary_tree_paths( if node.left is not None: for path in collect_binary_tree_paths(node.left): yield [node] + path + + +def binary_tree_from_list(inlist: list[int | None]) -> stdlib.BinaryTreeNode | None: + if len(inlist) == 0: + return None + + nodes: list[stdlib.BinaryTreeNode | None] = [ + typing.cast(stdlib.BinaryTreeNode | None, stdlib.TreeNode.from_int(i)) + for i in inlist + ] + nodes_copy = nodes[::-1] + root = nodes_copy.pop() + + for node in nodes: + if node is None: + continue + + if len(nodes_copy) == 0: + break + + node.left = nodes_copy.pop() + + if len(nodes_copy) > 0: + node.right = nodes_copy.pop() + + return root diff --git a/leetcode/test_stuff.py b/leetcode/test_stuff.py index 6954a65..042d5b1 100644 --- a/leetcode/test_stuff.py +++ b/leetcode/test_stuff.py @@ -144,7 +144,9 @@ def test_min_stack(ops: list[tuple[str] | tuple[str, int]], expected: list[int | ), ], ) -def test_sort_linked_list(head: stdlib.ListNode | None, expected: stdlib.ListNode | None): +def test_sort_linked_list( + head: stdlib.LinkedListNode | None, expected: stdlib.LinkedListNode | None +): if head is None: assert stuff.sort_linked_list(head) == expected return @@ -168,7 +170,7 @@ def test_sort_linked_list(head: stdlib.ListNode | None, expected: stdlib.ListNod ], ) def test_connect_binary_tree_right( - root: stdlib.Node | None, expected: list[int | None] | None + root: stdlib.ConnectableBinaryTreeNode | None, expected: list[int | None] | None ): if expected is None: assert root is None @@ -192,5 +194,38 @@ def test_connect_binary_tree_right( ), ], ) -def test_connect_binary_tree_sum_numbers(root: stdlib.Node | None, expected: int): +def test_connect_binary_tree_sum_numbers( + root: stdlib.BinaryTreeNode | None, expected: int +): assert stuff.sum_binary_tree_path_ints(root) == expected + + +@pytest.mark.parametrize( + ("inlist", "expected"), + [ + ( + [3, 5, 1, 6, 2, 0, 8, None, None, 7, 4], + stdlib.TreeNode( + 3, + left=stdlib.TreeNode( + 5, + left=stdlib.TreeNode(6), + right=stdlib.TreeNode( + 2, + left=stdlib.TreeNode(7), + right=stdlib.TreeNode(4), + ), + ), + right=stdlib.TreeNode( + 1, + left=stdlib.TreeNode(0), + right=stdlib.TreeNode(8), + ), + ), + ), + ], +) +def test_binary_tree_from_list( + inlist: list[int | None], expected: stdlib.BinaryTreeNode | None +): + assert stuff.binary_tree_from_list(inlist) == expected