class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', nodes: 'List[TreeNode]') -> 'TreeNode':
        """Return the lowest common ancestor of every node in ``nodes``.

        tc O(N)  sc O(N)

        Post-order traversal in which each finished subtree reports one value:
        - the target node itself, when the subtree root is a target (its
          children are not searched: if another target lies below it, this
          node is still their common ancestor);
        - the subtree root, when targets were reported from both children;
        - whatever the single reporting child reported, otherwise;
        - None, when the subtree contains no target.

        An explicit stack replaces recursion. A degenerate (list-shaped) tree
        deeper than Python's recursion limit therefore cannot raise
        RecursionError.

        Returns None when ``root`` is None or no target is found.
        """
        targets = set(nodes)
        # Value each finished subtree reports to its parent, keyed by the
        # subtree root. An entry is popped as soon as the parent reads it.
        reported = {}
        stack = [(root, False)]  # (node, children already processed?)
        while stack:
            node, children_done = stack.pop()
            if node is None:
                continue
            if node in targets:
                # A target is reported as-is; no need to look underneath it.
                reported[node] = node
                continue
            if not children_done:
                # Revisit this node once both children have reported.
                stack.append((node, True))
                stack.append((node.right, False))
                stack.append((node.left, False))
                continue
            # Missing keys (including a None child) report None.
            left = reported.pop(node.left, None)
            right = reported.pop(node.right, None)
            if left is not None and right is not None:
                # Targets found on both sides: this node is where they split.
                reported[node] = node
            else:
                reported[node] = left if left is not None else right
        return reported.get(root)

# Another solution: compare per-subtree target counts

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', nodes: 'List[TreeNode]') -> 'TreeNode':
        """Return the lowest common ancestor of every node in ``nodes``.

        tc O(N)  sc O(N)

        Similar to LC1644: count, in post-order, how many targets each subtree
        contains. Any node whose count equals the number of distinct targets
        is an ancestor-or-self of the LCA. Post-order finishes every
        descendant before its ancestors, so the first such node is the lowest
        one, and the search stops there instead of walking the rest of the tree.

        An explicit stack replaces recursion, so deep list-shaped trees do not
        hit Python's recursion limit. No state is stored on ``self``.

        Returns None when ``nodes`` is empty or not every target is in the tree.
        """
        targets = set(nodes)  # duplicates in ``nodes`` count once
        if not targets:
            # No targets means no meaningful ancestor.
            return None
        total = len(targets)
        # Target count of each finished subtree, keyed by its root. An entry
        # is popped once the parent has consumed it.
        counts = {}
        stack = [(root, False)]  # (node, children already processed?)
        while stack:
            node, children_done = stack.pop()
            if node is None:
                continue
            if not children_done:
                # Revisit this node once both children have been counted.
                stack.append((node, True))
                stack.append((node.right, False))
                stack.append((node.left, False))
                continue
            # A None child (never stored as a key) contributes 0.
            cnt = counts.pop(node.left, 0) + counts.pop(node.right, 0)
            if node in targets:
                cnt += 1
            if cnt == total:
                # First node (in post-order) covering all targets: the LCA.
                return node
            counts[node] = cnt
        return None