class Solution:
    """Distance between two values in a binary tree."""

    def findDistance(self, root: "TreeNode", p: int, q: int) -> int:
        """Return the number of edges between the nodes holding p and q.

        The lowest common ancestor (LCA) of the two values is located
        first; the answer is the depth of ``p`` below the LCA plus the
        depth of ``q`` below the LCA.

        Both traversals use an explicit stack instead of recursion so that
        very deep (chain-shaped) trees do not hit the interpreter's
        recursion limit.

        Time O(N), auxiliary space O(H) for a tree with N nodes and
        height H.

        Args:
            root: Root of the tree; nodes expose ``val``, ``left`` and
                ``right``. May be ``None``.
            p: Value of the first node.
            q: Value of the second node.

        Returns:
            The edge count between the two nodes. ``0`` when ``p == q``
            (the tree is not inspected in that case). ``float('inf')``
            when either value cannot be reached from the LCA, e.g. when it
            is absent from the tree.
        """
        if p == q:
            return 0  # early termination: no traversal needed

        lca = self._lowest_common_ancestor(root, p, q)
        return self._depth_below(lca, p) + self._depth_below(lca, q)

    @staticmethod
    def _lowest_common_ancestor(root, p, q):
        """Return the LCA of the nodes valued p and q (iterative post-order).

        Returns the single matching node when only one of the values is
        present, and ``None`` when neither is (or the tree is empty).
        """
        targets = (q, p)
        # Each frame is [node, state, left_result]:
        #   state 0 -> node not examined yet
        #   state 1 -> left subtree finished, its result is in `returned`
        #   state 2 -> right subtree finished, its result is in `returned`
        returned = None
        stack = [[root, 0, None]]
        while stack:
            frame = stack[-1]
            node, state = frame[0], frame[1]
            if state == 0:
                if node is None or node.val in targets:
                    # Leaf sentinel or a hit: do not descend any further.
                    returned = node
                    stack.pop()
                else:
                    frame[1] = 1
                    stack.append([node.left, 0, None])
            elif state == 1:
                frame[2] = returned
                frame[1] = 2
                stack.append([node.right, 0, None])
            else:
                left, right = frame[2], returned
                if left is not None and right is not None:
                    # Hits on both sides: this node is the meeting point.
                    returned = node
                else:
                    returned = left if left is not None else right
                stack.pop()
        return returned

    @staticmethod
    def _depth_below(start, target):
        """Return the minimum depth of a node valued ``target`` under ``start``.

        ``start`` itself counts as depth 0. Returns ``float('inf')`` when
        no such node exists (including when ``start`` is ``None``).
        """
        best = float('inf')
        stack = [(start, 0)]
        while stack:
            node, depth = stack.pop()
            # Skip empty children and branches that cannot beat the best hit.
            if node is None or depth >= best:
                continue
            if node.val == target:
                best = depth
                continue
            stack.append((node.right, depth + 1))
            stack.append((node.left, depth + 1))
        return best