# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Solver for the "diameter of a binary tree" problem."""

    def diameterOfBinaryTree(self, root: "TreeNode | None") -> int:
        """Return the diameter of the tree rooted at *root*, measured in edges.

        The diameter is the number of edges on the longest path between any
        two nodes; that path does not have to pass through the root.

        Args:
            root: Root node exposing ``left`` and ``right`` attributes, or
                ``None`` for an empty tree.

        Returns:
            The longest node-to-node path length in edges. An empty tree and
            a single-node tree both yield 0.
        """
        best = 0
        # Explicit post-order traversal instead of recursion, so a deep or
        # completely skewed tree cannot exceed the interpreter's recursion
        # limit. Each entry is (node, children_already_scheduled).
        pending = [(root, False)]
        # Heights (counted in nodes, empty subtree == 0) of finished
        # subtrees, in the order they were completed.
        heights = []

        while pending:
            node, expanded = pending.pop()
            if node is None:
                heights.append(0)
            elif not expanded:
                # Revisit this node once both subtrees are done. The left
                # child is pushed last so that it is processed first.
                pending.append((node, True))
                pending.append((node.right, False))
                pending.append((node.left, False))
            else:
                # The right subtree finished last, so its height is on top.
                right_height = heights.pop()
                left_height = heights.pop()
                # The longest path bending at this node descends through
                # both subtrees: left_height + right_height edges.
                best = max(best, left_height + right_height)
                heights.append(max(left_height, right_height) + 1)

        return best