"""Alpha Beta

Starter code for minimax and alpha-beta pruning.
"""
import unittest

class SearchTree:
    """A complete ``branch``-ary tree of fixed ``depth`` stored in a flat list.

    Nodes are addressed by ``(d, n)``: ``d`` is the depth (0 is the root) and
    ``n`` is the node's position within that level, ``0 <= n < branch ** d``.
    Levels are stored consecutively in ``self.data``: the root first, then all
    depth-1 nodes, then all depth-2 nodes, and so on. The children of
    ``(d, n)`` are ``(d + 1, branch * n + k)`` for ``k`` in ``range(branch)``.
    """

    def __init__(self, depth, branch_factor=2, fill_value=None):
        """Create a tree with levels ``0..depth``, every slot set to *fill_value*.

        Note that every slot references the same *fill_value* object, so a
        mutable fill value is shared by all nodes.

        Raises:
            ValueError: if *depth* is negative or *branch_factor* is below 1.
        """
        if depth < 0:
            raise ValueError(f"SearchTree: depth {depth} must be non-negative")
        if branch_factor < 1:
            raise ValueError(
                f"SearchTree: branch_factor {branch_factor} must be at least 1"
            )
        self.depth = depth
        self.branch = branch_factor
        # The offset of the level just below the leaves equals the total
        # number of nodes in levels 0..depth.
        self.data = [fill_value] * self._level_offset(depth + 1)

    def _level_offset(self, d):
        """Return the flat index of the first node at depth *d* (d >= 0).

        This is the node count of levels ``0..d-1``, i.e. the geometric series
        ``sum(branch ** k for k in range(d))``, computed in closed form.
        """
        if self.branch == 1:
            # Degenerate chain: one node per level.
            return d
        return (self.branch ** d - 1) // (self.branch - 1)

    def validate_index(self, d, n):
        """Map node ``(d, n)`` to its index in ``self.data``.

        Raises:
            ValueError: if *d* is outside ``0..self.depth`` or *n* is outside
                ``0 <= n < branch ** d``.
        """
        # Without this check a negative depth silently aliased the root and a
        # too-large depth produced an index past the end of ``data``.
        if not 0 <= d <= self.depth:
            raise ValueError(
                f"get_node: depth {d} must be between 0 and (tree depth {self.depth})"
            )
        if n < 0:
            raise ValueError(f"get_node: node number {n} must be non-negative")
        if n >= self.branch ** d:
            raise ValueError(
                f"get_node: node number {n} must be less than (branch {self.branch}) ** (depth {d})"
            )

        return self._level_offset(d) + n

    def get_node(self, d, n):
        """Return the value stored at node ``(d, n)``."""
        return self.data[self.validate_index(d, n)]

    def set_node(self, d, n, val):
        """Store *val* at node ``(d, n)``."""
        self.data[self.validate_index(d, n)] = val

    def get_child_n(self, d, n):
        """Return the level-``d + 1`` positions of the children of ``(d, n)``.

        Returns an empty list when *d* is the leaf level (or deeper). *n* is
        not validated.
        """
        if d >= self.depth:
            return []

        return [self.branch * n + k for k in range(self.branch)]

    def get_parent_n(self, d, n):
        """Return the level-``d - 1`` position of the parent of ``(d, n)``.

        Returns None for the root (``d == 0``). Neither *d* nor *n* is
        validated, so do not call this with an invalid depth.
        """
        if d == 0:
            return None

        return n // self.branch

class TestSeachTree(unittest.TestCase):
    """Unit tests for SearchTree's flat level-order storage and navigation.

    Uses ``assertEqual``: the ``assertEquals`` alias was removed in Python 3.12.
    """

    def test_tree_setup(self):
        t = SearchTree(2, branch_factor=3, fill_value=0)
        # 1 root + 3 nodes at depth 1 + 9 nodes at depth 2 = 13 slots.
        self.assertEqual(t.data, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

    def test_validate_index(self):
        t = SearchTree(2, 7)
        self.assertEqual(t.validate_index(0, 0), 0)
        self.assertEqual(t.validate_index(1, 0), 1)
        self.assertEqual(t.validate_index(1, 6), 7)
        self.assertEqual(t.validate_index(2, 0), 8)

    def test_basic_indexing(self):
        t = SearchTree(depth=2, branch_factor=3)
        t.set_node(0, 0, 100)
        t.set_node(1, 0, 200)
        t.set_node(1, 1, 201)
        t.set_node(1, 2, 202)
        for n in range(9):
            t.set_node(2, n, 300 + n)

        self.assertEqual(
            t.data, [100, 200, 201, 202, 300, 301, 302, 303, 304, 305, 306, 307, 308]
        )
        self.assertEqual(t.get_node(0, 0), 100)
        self.assertEqual(t.get_node(2, 8), 308)

    def test_children(self):
        t = SearchTree(depth=2, branch_factor=5)
        self.assertEqual(t.get_child_n(1, 3), [15, 16, 17, 18, 19])
        # Leaves have no children.
        self.assertEqual(t.get_child_n(2, 0), [])

    def test_parent(self):
        t = SearchTree(depth=2, branch_factor=5)
        self.assertIsNone(t.get_parent_n(0, 0))
        self.assertEqual(t.get_parent_n(1, 4), 0)
        self.assertEqual(t.get_parent_n(2, 17), 3)

def is_maximizing(depth: int) -> bool:
    """Return True when *depth* is a maximizing level of the search tree.

    Levels alternate by parity: even depths (including depth 0) give True,
    odd depths give False.
    """
    _, parity = divmod(depth, 2)
    return parity == 0