class Solution:
    def findMaxLength(self, nums: List[int]) -> int:
        """Return the length of the longest contiguous slice with a zero balance.

        Each element equal to 1 contributes +1 to a running balance; every
        other element contributes -1.  Two positions that share the same
        running balance delimit a slice whose contributions cancel out, so
        the widest such pair gives the answer.

        Args:
            nums: The sequence to scan (typically containing only 0s and 1s).

        Returns:
            The length of the longest balanced contiguous slice, or 0 when
            no such slice exists (including for empty input).
        """
        # Earliest index at which each running balance was observed.  The
        # sentinel entry (balance 0 at index -1) lets a slice starting at
        # index 0 be measured as ``idx - (-1)``.
        first_seen = {0: -1}
        best = 0
        balance = 0
        for idx, value in enumerate(nums):
            balance += 1 if value == 1 else -1
            if balance in first_seen:
                # Same balance seen before: everything in between cancels.
                best = max(best, idx - first_seen[balance])
            else:
                # Keep only the first occurrence so later spans are maximal.
                first_seen[balance] = idx
        return best