Main idea: current_presum - old_presum = target, where old_presum is a prefix sum already stored in hmap, so we check whether current_presum - target is in hmap. If it is, add its stored count to res; then record current_presum in hmap before entering the subtrees. Note: once the current node and its subtrees have been traversed, hmap must be restored to the state it had before the current node was added, because paths in other branches do not pass through the current node.
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
# time O(N) space O(N)
class Solution:
    """Count target-sum paths with prefix sums: O(N) time, O(N) space."""

    def pathSum(self, root: "TreeNode | None", sum: int) -> int:
        """Return the number of downward paths whose values add up to ``sum``.

        A path may start and end at any node, but only goes from parent to
        child. For every node the root-to-node prefix sum is computed; each
        ancestor prefix equal to ``prefix - sum`` marks the start of one path
        that ends at the node.

        Args:
            root: Root of the binary tree, or ``None`` for an empty tree.
                Nodes must expose ``val``, ``left`` and ``right``.
            sum: Target path sum. The name shadows the builtin; it is kept
                because it is part of the public signature.

        Returns:
            The number of matching paths; ``0`` for an empty tree.
        """
        if not root:
            return 0

        target = sum
        # Prefix sums on the path from the root to the current node, mapped
        # to how often each occurs. The empty prefix (0) lets paths that
        # start at the root be counted.
        prefix_counts = {0: 1}
        total = 0

        # An explicit stack replaces recursion so that very deep (list-like)
        # trees cannot exceed the interpreter's recursion limit.
        # Entries are (node, prefix sum, leaving flag): when entering, the
        # prefix excludes the node; when leaving, it includes the node.
        stack = [(root, 0, False)]
        while stack:
            node, prefix, leaving = stack.pop()

            if leaving:
                # The node's subtree is finished: forget its prefix so that
                # sibling branches, which do not pass through it, never see it.
                remaining = prefix_counts[prefix] - 1
                if remaining:
                    prefix_counts[prefix] = remaining
                else:
                    del prefix_counts[prefix]
                continue

            prefix += node.val
            # Count paths ending here *before* registering the node's own
            # prefix, otherwise a target of 0 would count the empty path.
            total += prefix_counts.get(prefix - target, 0)
            prefix_counts[prefix] = prefix_counts.get(prefix, 0) + 1

            # The leaving marker goes below the children so it is popped
            # only after both subtrees have been processed.
            stack.append((node, prefix, True))
            if node.right:
                stack.append((node.right, prefix, False))
            if node.left:
                stack.append((node.left, prefix, False))

        return total
Brute-force solution: time O(N^2), space O(H).
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
    """Brute force: try every node as a path start (O(N^2) time, O(H) space)."""

    def pathSum(self, root: "TreeNode | None", sum: int) -> int:
        """Return the number of downward paths whose values add up to ``sum``.

        Every node is taken in turn as the first node of a path, and all
        downward paths from it are walked while tracking the amount still
        missing to reach the target.

        Args:
            root: Root of the binary tree, or ``None`` for an empty tree.
                Nodes must expose ``val``, ``left`` and ``right``.
            sum: Target path sum. The name shadows the builtin; it is kept
                because it is part of the public signature.

        Returns:
            The number of matching paths; ``0`` for an empty tree. The same
            value is also stored on the instance as ``numPath``.
        """
        total = 0

        # Both traversals use explicit stacks instead of recursion so that
        # deep (list-like) trees cannot exceed the recursion limit.
        starts = [root] if root else []
        while starts:
            start = starts.pop()

            # Walk every downward path beginning at `start`; `remaining` is
            # what the current node's value must equal to complete a path.
            pending = [(start, sum)]
            while pending:
                node, remaining = pending.pop()
                if node.val == remaining:
                    total += 1
                # Keep descending even after a hit: zero or negative values
                # further down can produce additional matching paths.
                remaining -= node.val
                if node.left:
                    pending.append((node.left, remaining))
                if node.right:
                    pending.append((node.right, remaining))

            if start.right:
                starts.append(start.right)
            if start.left:
                starts.append(start.left)

        # Also publish the result as `numPath` for callers that read it from
        # the instance after the call.
        self.numPath = total
        return total
Simpler code: the same brute force written as two short recursions.
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
    """Compact recursive brute force: count paths from each node in turn."""

    def pathSum(self, root: "TreeNode | None", sum: int) -> int:
        """Return the number of downward paths whose values add up to ``sum``.

        The answer is the number of paths starting at ``root`` plus the
        answers for the left and right subtrees, computed by calling this
        method recursively. Recursion depth grows with the height of the
        tree, so extremely deep trees can hit the interpreter's recursion
        limit.

        Args:
            root: Root of the binary tree, or ``None`` for an empty tree.
                Nodes must expose ``val``, ``left`` and ``right``.
            sum: Target path sum. The name shadows the builtin; it is kept
                because it is part of the public signature.

        Returns:
            The number of matching paths; ``0`` for an empty tree.
        """
        # Empty subtree: nothing to count, and no need to build the helper.
        if not root:
            return 0

        def count_from(node, remaining):
            """Count downward paths that start at ``node`` and sum to ``remaining``."""
            if not node:
                return 0
            hit = 1 if node.val == remaining else 0
            # Keep descending after a hit: zero or negative values below can
            # complete further paths.
            rest = remaining - node.val
            return hit + count_from(node.left, rest) + count_from(node.right, rest)

        # Paths starting at the root, plus paths starting anywhere in either
        # subtree.
        return (
            count_from(root, sum)
            + self.pathSum(root.left, sum)
            + self.pathSum(root.right, sum)
        )