class RandomizedCollection:
    """Multiset with average O(1) insert, remove and random pick.

    Design:
    1. ``li`` is a flat list holding every stored value, duplicates included.
    2. ``val_idx`` maps each value to the set of positions it occupies in
       ``li``. Values with no remaining occurrence have no key at all.
    3. ``remove`` overwrites one occurrence of the value with the last list
       element, repairs that element's position set, then pops the tail.
    4. ``getRandom`` picks from ``li``, so a value stored k times is k times
       as likely to be returned as a value stored once.
    """

    def __init__(self) -> None:
        self.val_idx = collections.defaultdict(set)  # value -> positions in self.li
        self.li = []  # every stored value, duplicates included

    def insert(self, val: int) -> bool:
        """Add one occurrence of ``val``.  Time O(1), space O(N) overall.

        Returns:
            True if ``val`` was not in the collection before this call,
            False if it was already present at least once.
        """
        positions = self.val_idx[val]
        positions.add(len(self.li))  # the new element lands at the tail
        self.li.append(val)
        return len(positions) == 1

    def remove(self, val: int) -> bool:
        """Remove one occurrence of ``val`` if present.  Average time O(1).

        Returns:
            True if an occurrence was removed, False if ``val`` was absent.
        """
        # Use .get() rather than indexing: indexing the defaultdict would
        # insert an empty set for every value that was never stored.
        positions = self.val_idx.get(val)
        if not positions:
            return False

        idx = positions.pop()  # any occurrence will do
        tail = len(self.li) - 1
        last = self.li[tail]

        # Move the tail element into the freed slot and repair its index set.
        # Add before discard so the idx == tail case ends with the position
        # removed rather than re-added.
        self.li[idx] = last
        moved = self.val_idx[last]
        moved.add(idx)
        moved.discard(tail)
        self.li.pop()

        # Drop the key once its last occurrence is gone so the map does not
        # accumulate empty sets.
        if not positions:
            del self.val_idx[val]
        return True

    def getRandom(self) -> int:
        """Return a random element of the collection.  Time O(1).

        Raises:
            IndexError: if the collection is empty (from ``random.choice``).
        """
        return random.choice(self.li)

About time complexity: random.choice on a list is O(1), and set.discard(val) is O(1) on average (refer to the Python documentation for both).

Traditional solution: the main array records each value together with its position in the hash table's index list.

class RandomizedCollection:
    """Multiset with O(1) insert, remove and random pick.

    Main idea: each slot of the main array records not only the value but
    also where that slot's own index is stored in the hash table, so both
    structures can be patched in constant time when a slot moves.

    Invariant: for every ``i``, with ``arr[i] == [v, p]``, ``d[v][p] == i``.
    Values with no remaining occurrence have no key in ``d``.
    """

    def __init__(self) -> None:
        self.d = collections.defaultdict(list)  # val -> [idx1_in_arr, idx2, ...]
        self.arr = []  # entries are [val, pos_in_d]

    def insert(self, val: int) -> bool:
        """Add one occurrence of ``val``.

        Returns:
            True if ``val`` was not in the collection before this call,
            False if it was already present at least once.
        """
        slots = self.d[val]
        slots.append(len(self.arr))  # the new element lands at the tail
        self.arr.append([val, len(slots) - 1])
        return len(slots) == 1

    def remove(self, val: int) -> bool:
        """Remove one occurrence of ``val`` if present.

        Returns:
            True if an occurrence was removed, False if ``val`` was absent.
        """
        # Use .get() rather than indexing: indexing the defaultdict would
        # insert an empty list for every value that was never stored.
        slots = self.d.get(val)
        if not slots:
            return False

        # The slot being freed may itself be the last slot of arr; the
        # steps below are still correct in that case.
        idx = slots[-1]
        last, last_d_pos = self.arr[-1]

        # Move the tail entry into the freed slot and repoint its
        # back-reference before shrinking either structure.
        self.arr[idx] = [last, last_d_pos]
        self.d[last][last_d_pos] = idx
        slots.pop()
        self.arr.pop()

        # Drop the key once its last occurrence is gone so the map does not
        # accumulate empty lists.
        if not slots:
            del self.d[val]
        return True

    def getRandom(self) -> int:
        """Return a random element of the collection.

        Raises:
            ValueError: if the collection is empty (``random.randint`` is
                called with an empty range).
        """
        idx = random.randint(0, len(self.arr) - 1)
        return self.arr[idx][0]