class Solution:
MIN_VAL = -4 * (10**4)
MAX_VAL = 4 * (10**4)
def __init__(self):
self.max_sum = 0
def maxSumBST(self, root: Optional[TreeNode]) -> int:
self.max_sum = 0
self._travel(root)
return self.max_sum
def _travel(self, node):
"""
node: TreeNode root to check
return:
tuple(Bool, Number, Number, Number)
is_bst, tree_min, tree_max, tree_sum
"""
if not node:
return True, self.MAX_VAL, self.MIN_VAL, 0
is_left_bst, left_min, left_max, left_sum = self._travel(node.left)
is_right_bst, right_min, right_max, right_sum = self._travel(node.right)
if is_left_bst and is_right_bst \
and node.val > left_max \
and node.val < right_min:
tree_sum = node.val + left_sum + right_sum
result = (
True,
min(left_min, node.val),
max(right_max, node.val),
tree_sum
)
self.max_sum = max(self.max_sum, tree_sum)
return result
return False, None, None, 0