# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    # NOTE: annotations are quoted (forward references) because TreeNode is
    # supplied by the judge environment and is not defined in this file.
    def pruneTree(self, root: "TreeNode") -> "TreeNode":
        """Remove every subtree that does not contain a 1.

        The tree is pruned in place using a post-order traversal: both
        children are processed before their parent, so when a node is
        examined any child still attached is known to contain a 1.

        Args:
            root: Root node of the binary tree, or None for an empty tree.
                Nodes are accessed only through ``val``, ``left`` and
                ``right``.

        Returns:
            ``root`` itself, with every 1-free subtree detached, or None when
            no node in the tree has value 1.

        Complexity:
            Time O(N). Extra space O(h) for the recursion stack, where h is
            the tree height (h can reach N for a degenerate tree).
        """
        def contains_one(node) -> bool:
            """Prune below ``node`` and report whether its subtree holds a 1."""
            if node is None:
                return False
            # Post-order: decide about each child before deciding about node.
            if not contains_one(node.left):
                node.left = None
            if not contains_one(node.right):
                node.right = None
            # A node survives if it is a 1 itself or still has a child
            # (children that remain attached always contain a 1).
            return (
                node.val == 1
                or node.left is not None
                or node.right is not None
            )

        return root if contains_one(root) else None

    # Test cases (input -> expected output):
    #   [0,null,0,1,1,null,1,null,1,null,null,null,null] -> [0,null,0,1,1,null,1,null,1]
    #   [0,null,0,0,0]                                   -> []
    #   [1,0,1,0,0,0,1]                                  -> [1,null,1,null,1]

# Optimization: the same pruning, recursing on pruneTree itself instead of a helper.


# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    # NOTE: annotations are quoted (forward references) because TreeNode is
    # supplied by the judge environment and is not defined in this file.
    def pruneTree(self, root: "TreeNode") -> "TreeNode":
        """Remove every subtree that does not contain a 1.

        Recurses on the children first and re-attaches whatever survives, so
        the tree is modified in place.

        Args:
            root: Root node of the binary tree, or None for an empty tree.
                Nodes are accessed only through ``val``, ``left`` and
                ``right``.

        Returns:
            ``root`` with all 1-free subtrees detached, or None when the
            subtree rooted at ``root`` contains no 1.
        """
        if root is None:
            return None
        # Post-order: prune both subtrees before judging this node.
        root.left = self.pruneTree(root.left)
        root.right = self.pruneTree(root.right)
        # A 0-valued node left without children contributes no 1: drop it.
        if root.val == 0 and root.left is None and root.right is None:
            return None
        return root