# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """
    Delete a key from a binary search tree (BST) and return the new root.

    Approach: walk down from the root to the node holding ``key`` (go right
    when key > node.val, left when key < node.val), remembering its parent.
    Once found, the node is removed according to how many children it has:

    1. No children  -> it is simply dropped (replaced by None).
    2. One child    -> it is replaced by that child.
    3. Two children -> its value is overwritten with the largest value in its
       left subtree (the in-order predecessor), and the predecessor node --
       which has no right child -- is spliced out of the left subtree.

    The parent's link is then pointed at whatever replaced the node. The walk
    is iterative, so a degenerate (linked-list shaped) tree cannot exceed
    Python's recursion limit.

    Time O(H), where H is the tree height (O(log N) when balanced, O(N) when
    skewed). Extra space O(1).
    """

    def deleteNode(self, root: "TreeNode | None", key: int) -> "TreeNode | None":
        """Remove the node whose value equals ``key`` and return the new root.

        :param root: root of the BST, or None for an empty tree.
        :param key: value to delete; if absent the tree is returned unchanged.
        :return: root of the modified tree (None if the tree becomes empty).
        """
        parent, node = None, root
        while node is not None and node.val != key:
            parent = node
            node = node.left if key < node.val else node.right

        if node is None:  # key not in the tree (or tree is empty)
            return root

        replacement = self._detach(node)
        if parent is None:  # the deleted node was the root
            return replacement
        if parent.left is node:
            parent.left = replacement
        else:
            parent.right = replacement
        return root

    @staticmethod
    def _detach(node: "TreeNode") -> "TreeNode | None":
        """Return the subtree that should take ``node``'s place once it is deleted."""
        # Cases 1 and 2: at most one child -> promote it (None for a leaf).
        if node.left is None:
            return node.right
        if node.right is None:
            return node.left

        # Case 3: two children. Find the in-order predecessor (max of the
        # left subtree) together with its parent.
        pred_parent, pred = node, node.left
        while pred.right is not None:
            pred_parent, pred = pred, pred.right

        # Copy the predecessor's value up, then unlink the predecessor. It has
        # no right child, so its left subtree (possibly None) takes its place.
        node.val = pred.val
        if pred_parent is node:
            node.left = pred.left
        else:
            pred_parent.right = pred.left
        return node