class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', nodes: 'List[TreeNode]') -> 'TreeNode':
        """Return the lowest common ancestor of all ``nodes`` in the tree at ``root``.

        Time O(N), space O(N), where N is the number of tree nodes.

        The traversal is an iterative post-order walk (explicit stack) so that
        deep or skewed trees do not hit Python's recursion limit.

        Args:
            root: Root of the binary tree; may be ``None``.
            nodes: Target nodes, compared by hash/equality against tree nodes.

        Returns:
            The lowest node whose subtree contains every target that is
            present in the tree, or ``None`` when no target is found
            (including an empty tree or an empty ``nodes``).
        """
        targets = set(nodes)
        # found[n] is the "representative" of the subtree rooted at n: the
        # target / LCA discovered there, or None when it holds no target.
        found = {}
        stack = [(root, False)]
        while stack:
            node, children_done = stack.pop()
            if node is None:
                continue
            if node in targets:
                # A target is the LCA of every target beneath it, so its
                # subtree does not need to be explored.
                found[node] = node
                continue
            if not children_done:
                # Revisit this node once both children have been resolved.
                stack.append((node, True))
                stack.append((node.right, False))
                stack.append((node.left, False))
                continue
            # Child entries are consumed here to keep the map small.
            left_hit = found.pop(node.left, None)
            right_hit = found.pop(node.right, None)
            if left_hit is not None and right_hit is not None:
                # Targets on both sides: this node is where they meet.
                found[node] = node
            elif left_hit is not None:
                found[node] = left_hit
            else:
                found[node] = right_hit
        return found.get(root)
# Another solution: count the target nodes found in each subtree.
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, x):
# self.val = x
# self.left = None
# self.right = None
class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', nodes: 'List[TreeNode]') -> 'TreeNode':
        """Return the lowest common ancestor of all ``nodes`` by counting targets.

        Time O(N), space O(N), where N is the number of tree nodes.
        Similar to LC1644: walk the tree in post-order and count how many
        targets each subtree contains; the first node whose count equals the
        number of distinct targets is the lowest common ancestor.

        The walk uses an explicit stack so deep or skewed trees do not hit
        Python's recursion limit. The result is also stored on ``self.LCA``.

        Args:
            root: Root of the binary tree; may be ``None``.
            nodes: Target nodes, compared by hash/equality against tree nodes.

        Returns:
            The lowest node whose subtree contains every distinct target, or
            ``None`` when ``nodes`` is empty or not all targets are in the tree.
        """
        self.LCA = None
        targets = set(nodes)
        if not targets:
            # Without this guard a count of 0 would "match" at the first leaf.
            return None
        wanted = len(targets)
        # subtree_hits[n] = number of targets inside the subtree rooted at n.
        subtree_hits = {}
        stack = [(root, False)]
        while stack:
            node, children_done = stack.pop()
            if node is None:
                continue
            if not children_done:
                # Revisit this node once both children have been counted.
                stack.append((node, True))
                stack.append((node.right, False))
                stack.append((node.left, False))
                continue
            # Child entries are consumed here to keep the map small.
            hits = subtree_hits.pop(node.left, 0) + subtree_hits.pop(node.right, 0)
            if node in targets:
                hits += 1
            if hits == wanted:
                # Post-order reaches descendants first, so the first full
                # count belongs to the lowest qualifying node.
                self.LCA = node
                return node
            subtree_hits[node] = hits
        return None