class Solution:
    def maxSum(self, A: List[int], B: List[int]) -> int:
        """
        Return the best achievable path score over A and B, modulo 1e9+7.

        Both arrays are walked in merged (ascending) order with one cursor
        each. Every cursor carries the best score of any path that ends at
        the element it just consumed. When the two cursors meet on an equal
        value, a path may switch arrays there, so both running scores
        collapse to the larger one plus the shared value.

        Time O(N + M), extra space O(1).
        """
        MOD = 10**9 + 7
        len_a, len_b = len(A), len(B)
        pa = pb = 0
        best_a = best_b = 0

        # Phase 1: both arrays still have elements; always advance the
        # cursor sitting on the smaller value so shared values line up.
        while pa < len_a and pb < len_b:
            x, y = A[pa], B[pb]
            if x < y:
                best_a += x
                pa += 1
            elif y < x:
                best_b += y
                pb += 1
            else:
                # Shared value: keep whichever prefix scored higher.
                best_a = best_b = max(best_b, best_a) + x
                pa += 1
                pb += 1

        # Phase 2: at most one array has a tail left; no more shared values
        # are possible, so the tail is simply accumulated.
        for k in range(pa, len_a):
            best_a += A[k]
        for k in range(pb, len_b):
            best_b += B[k]

        # Reduce only at the very end so the comparison uses true sums.
        return max(best_a, best_b) % MOD

# DFS solution

class Solution:
    def maxSum(self, A: List[int], B: List[int]) -> int:
        """
        Return the best achievable path score over A and B, modulo 1e9+7.

        Graph formulation: every value is a node, and each array contributes
        an edge from an element to the element that follows it. A value that
        appears in both arrays therefore has two outgoing edges (stay on the
        current array or switch). The answer is the heaviest node-weighted
        path starting from A[0] or B[0].

        The traversal is a memoised depth-first search driven by an explicit
        stack, so long inputs do not hit Python's recursion limit. It relies
        on the successor graph being acyclic, which holds when both arrays
        are strictly increasing.

        Time O(N + M): each value is resolved once and has at most two
        successors. Space O(N + M) for the graph, memo and stack.

        Returns 0 when both arrays are empty; an empty array simply
        contributes no starting point.
        """
        MOD = 10**9 + 7

        # value -> list of values that may directly follow it on a path
        successors = {}
        for path in (A, B):
            for cur, nxt in zip(path, path[1:]):
                successors.setdefault(cur, []).append(nxt)

        best = {}  # value -> max (un-reduced) sum of a path starting there

        def best_from(start):
            # Iterative post-order DFS: a node is resolved only after all of
            # its successors have been resolved.
            stack = [start]
            while stack:
                node = stack[-1]
                if node in best:
                    stack.pop()
                    continue
                nexts = successors.get(node, ())
                pending = [n for n in nexts if n not in best]
                if pending:
                    stack.extend(pending)
                    continue
                tail = 0
                for n in nexts:
                    tail = max(tail, best[n])
                best[node] = node + tail
                stack.pop()
            return best[start]

        starts = [path[0] for path in (A, B) if path]
        if not starts:
            return 0
        # Compare the true sums first; reducing each candidate modulo MOD
        # before taking the max would compare residues and pick wrongly.
        return max(best_from(s) for s in starts) % MOD