class Solution:
    def findMaxLength(self, nums: List[int]) -> int:
        """Return the length of the longest contiguous subarray with equal 0s and 1s.

        A running balance adds +1 for each 1 and -1 for every other value.
        Two prefixes with the same balance enclose a subarray whose
        contributions cancel out, i.e. one with as many 1s as non-1s.

        Args:
            nums: Sequence of values; 1 counts as +1, anything else as -1
                (intended to be a binary 0/1 array).

        Returns:
            Length of the longest balanced subarray, or 0 if there is none
            (including when ``nums`` is empty).
        """
        # Maps a prefix balance to the EARLIEST index where it occurred.
        # Seeding 0 -> -1 handles balanced subarrays that start at index 0:
        # their length is i - (-1) = i + 1.
        first_seen = {0: -1}
        max_len = 0
        balance = 0
        for i, num in enumerate(nums):
            balance += 1 if num == 1 else -1
            if balance in first_seen:
                # Keeping only the earliest index means this is the longest
                # balanced subarray that ends at i.
                max_len = max(max_len, i - first_seen[balance])
            else:
                first_seen[balance] = i
        return max_len