# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
"""
main idea: there 3 cases for deletion. (1)both left kid and right child is None (1) either child is None (3) both children are not None. there are 3 conditions to determ: key > root.val; key < root.val; key == root.val. be careful not mess up the logic.
step1, check if root exist
step2, compare key with root.val, if root.val < key, means target node is on root's right subtree. so recursion to root.right. vice versa
step3, discuss three casses. for Case (1) just return; Case (2) replace root with root.left or root.right depends on either of them exists; Case (3) can choose right subtree's mininum number to replace with root.val and delete minimun number in right subtree
"""
#time O(lgN) space O(H) for the recursion
def deleteNode(self, root: TreeNode, key: int) -> TreeNode:
if not root:
return
if key > root.val:
root.right = self.deleteNode(root.right, key)
elif key < root.val:
root.left = self.deleteNode(root.left, key)
#not the key is root of a subtree
else:
# case 1: No child
if not root.left and not root.right:
root.val = None
return
# case 2 : One child
elif not root.left:
tmp = root
root = root.right
tmp.right = None
elif not root.right:
tmp = root
root = root.left
tmp.left = None
#Case 3: Two Children
else:
# find max val in left sub tree
tmp = root.left
while tmp.right:
tmp = tmp.right
# replace root val with max in the left subtree
root.val = tmp.val
#tmp.right = root.right
#return root.left
root.left = self.deleteNode(root.left,tmp.val)
return root