# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def maximumAverageSubtree(self, root: "TreeNode") -> float:
        """Return the largest average value over all subtrees of *root*.

        A subtree is a node together with all of its descendants; its
        average is the sum of its node values divided by its node count.

        The tree is walked in post-order so that each node's
        ``(count, total)`` pair is derived from the already-computed pairs
        of its children. The traversal uses an explicit stack rather than
        recursion, so deep or skewed trees cannot hit the interpreter's
        recursion limit.

        Args:
            root: Root node exposing ``val``, ``left`` and ``right``
                attributes, or ``None`` for an empty tree.

        Returns:
            The maximum subtree average as a float. An empty tree has no
            subtrees and yields ``0.0``.

        Complexity:
            O(N) time; O(H) auxiliary space, where H is the tree height.
        """
        if root is None:
            return 0.0

        # Start from -inf (not 0) so trees whose averages are all negative
        # still report their true maximum.
        best = float("-inf")

        # Work items are (node, children_done). A node is pushed once to
        # schedule its children and revisited once they are finished.
        pending = [(root, False)]
        # (count, total) pairs of finished subtrees, in completion order.
        finished = []

        while pending:
            node, children_done = pending.pop()
            if node is None:
                # An empty child contributes no nodes and no value.
                finished.append((0, 0))
            elif not children_done:
                # Revisit this node after both children; the left child is
                # pushed last so it is processed first.
                pending.append((node, True))
                pending.append((node.right, False))
                pending.append((node.left, False))
            else:
                # The right subtree finished last, so it is on top.
                right_count, right_total = finished.pop()
                left_count, left_total = finished.pop()
                count = left_count + right_count + 1
                total = left_total + right_total + node.val
                # True division: the average must not be floored.
                best = max(best, total / count)
                finished.append((count, total))

        return best