# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Binary-tree tilt calculator."""

    # The annotation is quoted so the class can be defined even when TreeNode
    # is not yet in scope (it is only present as a comment in this file).
    def findTilt(self, root: "TreeNode") -> int:
        """Return the sum of the tilts of every node in the tree.

        A node's tilt is the absolute difference between the sum of all
        values in its left subtree and the sum of all values in its right
        subtree; a missing subtree contributes a sum of 0.

        The traversal is an iterative post-order walk, so arbitrarily deep
        (e.g. fully skewed) trees do not hit Python's recursion limit.

        Time O(N), auxiliary space O(H) for N nodes and tree height H.

        Args:
            root: Root node exposing ``val``, ``left`` and ``right``
                attributes, or ``None`` for an empty tree.

        Returns:
            The total tilt of the tree; 0 for an empty tree.
        """
        total_tilt = 0
        # Pending work: (node, children_done). A node is pushed once to
        # schedule its children and once more to be combined afterwards.
        pending = [(root, False)]
        # Sums of finished subtrees that are waiting for their parent.
        subtree_sums = []

        while pending:
            node, children_done = pending.pop()
            if node is None:
                subtree_sums.append(0)
            elif children_done:
                # The left subtree finished first, so the right sum is on top.
                right_sum = subtree_sums.pop()
                left_sum = subtree_sums.pop()
                total_tilt += abs(left_sum - right_sum)
                subtree_sums.append(left_sum + right_sum + node.val)
            else:
                pending.append((node, True))
                # Push right before left so the left subtree is popped first.
                pending.append((node.right, False))
                pending.append((node.left, False))

        return total_tilt