Top-down (DFS + memoization)

class Solution:
    def findAllConcatenatedWordsInADict(self, words: List[str]) -> List[str]:
        """Return the words that are a concatenation of at least two words of ``words``.

        Approach: top-down DFS with memoisation.
        1. Put every word into a set for O(1) average membership tests.
        2. For each word, try every split point ``i`` with a non-empty prefix
           ``s[:i]`` and non-empty suffix ``s[i:]``.
        3. The word qualifies when the prefix is a dictionary word and the
           suffix is either a dictionary word itself or can, recursively, be
           split into two or more dictionary words.
        4. Every answer is cached in one memo shared by all words, so a
           suffix reached from several words is only solved once.

        Cost: each distinct string is solved once; solving a string of length
        W tries at most W - 1 split points and each slice / hash is O(W), so
        time is about O(W^2) per memoised string. The memo holds one entry
        per distinct string that was explored.

        Args:
            words: candidate words; duplicates and empty strings are tolerated.

        Returns:
            The concatenated words, each reported once, in order of first
            appearance in ``words``.
        """
        words_set = set(words)
        memo = {}  # string -> True when it splits into >= 2 dictionary words

        def dfs(s):
            if s in memo:
                return memo[s]
            found = False
            # i runs over 1..len(s)-1 so neither the prefix nor the suffix
            # is ever the empty string.
            for i in range(1, len(s)):
                if s[:i] not in words_set:
                    continue
                suffix = s[i:]
                # Suffix is a single word (two pieces in total) or is itself
                # a concatenation (three or more pieces in total).
                if suffix in words_set or dfs(suffix):
                    found = True
                    break
            memo[s] = found
            return found

        # dict.fromkeys de-duplicates while keeping input order, so the result
        # order is deterministic instead of following set iteration order.
        return [w for w in dict.fromkeys(words) if dfs(w)]

Optimization: bottom-up DP

class Solution:
    def findAllConcatenatedWordsInADict(self, words: List[str]) -> List[str]:
        """Return the words that are a concatenation of shorter words of ``words``.

        Approach: bottom-up word-break DP over the words sorted by length.
        1. Process the distinct, non-empty words from shortest to longest.
        2. For the current word ``s``, ``dp[i]`` is True when ``s[:i]`` can be
           built from the words seen so far. Every seen word is distinct from
           ``s`` and no longer than it, so ``dp[len(s)]`` being True means
           ``s`` is made of at least two shorter words.
        3. Only after ``s`` has been checked is it added to the seen set, so
           a word can never match itself.

        Cost: O(W^2) (i, j) pairs per word with an O(W) slice / hash each,
        i.e. O(N * W^3) time overall; the seen set plus one dp array of
        extra space.

        Args:
            words: candidate words. The list is not modified; duplicates and
                empty strings are ignored.

        Returns:
            The concatenated words ordered by increasing length, ties in
            input order.
        """
        # Work on a sorted, de-duplicated copy: the caller's list is left
        # untouched, and a repeated word cannot be "built" from its own
        # earlier copy. Empty strings can never be a piece or a result.
        candidates = sorted((w for w in dict.fromkeys(words) if w), key=len)
        if not candidates:
            return []

        # No piece can be shorter than the shortest word.
        min_len = len(candidates[0])
        seen = set()

        def bottom_up(s):
            n = len(s)
            dp = [False] * (n + 1)
            dp[0] = True  # the empty prefix is trivially buildable
            # Prefixes shorter than min_len cannot be built, so start there.
            for i in range(min_len, n + 1):
                # The last piece s[j:i] must have length >= min_len.
                for j in range(i - min_len + 1):
                    if dp[j] and s[j:i] in seen:
                        dp[i] = True
                        break
            seen.add(s)
            return dp[n]

        res = []
        for w in candidates:
            if bottom_up(w):
                res.append(w)
        return res

Trie + DP solution


class TrieNode:
    """One node of a character trie."""

    def __init__(self):
        # True once some inserted word ends exactly at this node.
        self.isEnd = False
        # Maps the next character to the child node reached through it.
        self.children = {}


class Trie:
    """Character trie used to test whether a word is built from stored words."""

    def __init__(self):
        self.root = TrieNode()

    def add(self, word):
        """Insert ``word``, creating any missing nodes along its path."""
        cur = self.root
        for c in word:
            if c not in cur.children:
                cur.children[c] = TrieNode()
            cur = cur.children[c]
        cur.isEnd = True
        # Keep the full word on its terminal node (not read by this class).
        cur.word = word

    def is_concatenated(self, word):
        """Return True when ``word`` splits into at least two stored words.

        ``dp[i]`` is the largest number of stored words that ``word[i:]`` can
        be cut into, or -1 when no such cut exists. The table is filled from
        the back: the empty suffix needs zero words.
        """
        n = len(word)
        dp = [-1] * (n + 1)
        dp[n] = 0
        for i in range(n - 1, -1, -1):
            # Walk the trie along word[i:], stopping at the first dead end.
            cur = self.root
            for j in range(i, n):
                cur = cur.children.get(word[j])
                if cur is None:
                    break
                # word[i:j+1] is a stored word and the remainder word[j+1:]
                # can itself be cut into stored words.
                if cur.isEnd and dp[j + 1] != -1:
                    # Keep the best count; a later j must not overwrite it.
                    dp[i] = max(dp[i], dp[j + 1] + 1)
        return dp[0] >= 2
         
    
class Solution:
    def findAllConcatenatedWordsInADict(self, words: List[str]) -> List[str]:
        """Collect the words that the trie reports as concatenated.

        tc O(N*W^2)  sc O(N*W)
        1. Insert every word into a trie.
        2. Ask the trie, word by word, whether it splits into at least two
           stored words, keeping the matches in input order.
        """
        trie = Trie()
        for word in words:
            trie.add(word)

        return [word for word in words if trie.is_concatenated(word)]