class Solution:
    def findDistance(self, root: "TreeNode", p: int, q: int) -> int:
        """Return the number of edges between the nodes whose values are p and q.

        Node values are assumed to be unique. Each target's root-to-node path is
        found with an explicit-stack preorder search. The two paths share
        exactly the ancestors down to the lowest common ancestor, so the distance
        is ``len(path_p) + len(path_q) - 2 * shared``.

        No recursion is used, so deep or skewed trees cannot hit Python's
        recursion limit.

        Time O(N), extra space O(H). The path list and the DFS stack both stay
        proportional to the tree height.

        Args:
            root: Root of the binary tree. Nodes expose ``val``, ``left`` and
                ``right``; ``None`` means an empty tree.
            p: Value of the first node.
            q: Value of the second node.

        Returns:
            The edge count between the two nodes. Returns 0 when ``p == q``.
            Returns ``float('inf')`` if either value is not present in the tree,
            which matches the previous behavior.
        """
        if p == q:
            return 0  # early termination

        def path_to(target):
            """Return the list of nodes from root to the node valued `target`, or None."""
            path = []
            stack = [(root, 0)] if root is not None else []
            while stack:
                node, depth = stack.pop()
                # Everything at index >= depth belongs to an already-finished
                # branch; what remains are exactly the ancestors of `node`.
                del path[depth:]
                path.append(node)
                if node.val == target:
                    return path
                # Push right before left so the left subtree is explored first.
                if node.right is not None:
                    stack.append((node.right, depth + 1))
                if node.left is not None:
                    stack.append((node.left, depth + 1))
            return None

        path_p = path_to(p)
        path_q = path_to(q)
        if path_p is None or path_q is None:
            return float('inf')  # a target is missing from the tree

        # Length of the common root prefix = depth(LCA) + 1.
        shared = 0
        for a, b in zip(path_p, path_q):
            if a is not b:
                break
            shared += 1
        return len(path_p) + len(path_q) - 2 * shared