class Solution:
    def countTriplets(self, A: List[int]) -> int:
        """Count triplets (i, j, k) with i < j <= k whose two halves XOR equal.

        With a = A[i] ^ ... ^ A[j-1] and b = A[j] ^ ... ^ A[k], the condition
        a == b is equivalent to A[i] ^ A[i+1] ^ ... ^ A[k] == 0.  Once a
        subarray A[i..k] XORs to zero, every split point j in (i, k] is
        valid, so that subarray contributes k - i triplets.

        Brute force over all (i, k) pairs with a running XOR:
        time O(N^2), extra space O(1).

        Args:
            A: the input integers.

        Returns:
            The number of qualifying triplets.
        """
        size = len(A)
        triplets = 0
        for start, first in enumerate(A):
            # running holds A[start] ^ ... ^ A[end] as end advances.
            running = first
            for end in range(start + 1, size):
                running ^= A[end]
                if running != 0:
                    continue
                # Zero-XOR subarray: any j in start+1..end splits it evenly.
                triplets += end - start
        return triplets
# O(N) solution: prefix XOR with hash maps
class Solution:
    def countTriplets(self, A: List[int]) -> int:
        """Count triplets (i, j, k) with i < j <= k whose two halves XOR equal.

        With a = A[i] ^ ... ^ A[j-1] and b = A[j] ^ ... ^ A[k], a == b holds
        exactly when A[i] ^ ... ^ A[k] == 0, i.e. when the prefix XOR before
        index i equals the prefix XOR through index k.

        Let P[p] be the XOR of the first p elements (P[0] == 0).  Every pair
        of prefix positions q < p with P[q] == P[p] marks a zero-XOR subarray
        and contributes p - q - 1 triplets.  If the value P[p] has been seen
        f times before, at positions q1..qf, the contribution at p is

            (p - q1 - 1) + ... + (p - qf - 1) = f * (p - 1) - (q1 + ... + qf)

        so it suffices to track, per prefix value, how often it occurred and
        the sum of the positions where it occurred.

        Time O(N), extra space O(N) for the two hash maps.  The input list is
        not modified.

        Args:
            A: the input integers.

        Returns:
            The number of qualifying triplets.
        """
        seen_count = defaultdict(int)   # prefix value -> occurrences so far
        position_sum = defaultdict(int)  # prefix value -> sum of positions
        # The empty prefix (position 0, XOR 0) is seeded here instead of
        # inserting a sentinel 0 into the caller's list.
        seen_count[0] = 1

        res = 0
        prefix = 0
        for pos, value in enumerate(A, 1):
            prefix ^= value
            res += seen_count[prefix] * (pos - 1) - position_sum[prefix]
            seen_count[prefix] += 1
            position_sum[prefix] += pos
        return res