class Solution:
    def kthSmallest(self, mat: List[List[int]], k: int) -> int:
        """Return the k-th smallest sum formed by picking one element per row.

        Main idea: fold the rows together one at a time. After each fold, keep
        only the k smallest partial sums. A partial sum outside the k smallest
        can never produce one of the k smallest full sums, because each of
        those k smaller partial sums plus the same later picks is no larger.

        Every row of ``mat`` must be sorted in non-decreasing order; the
        pairwise merge below relies on it.

        Time O(m * k log k) and extra space O(k), where m = len(mat).

        Raises IndexError if ``mat`` is empty or ``k`` exceeds the number of
        possible sums.
        """
        from heapq import heapify, heappop, heappush

        def k_small_pairs(A: List[int], B: List[int], k: int) -> List[int]:
            """Return up to k smallest sums A[i] + B[j] in ascending order.

            A and B must both be sorted ascending. The heap holds one frontier
            entry (A[i] + B[j], i, j) per index i of A, pointing at the next
            unused column j of B, so it never exceeds min(k, len(A)) entries.
            """
            # Only the first k entries of A can take part in the k smallest
            # sums (A is sorted), and each starts at column 0 of B. The seed
            # is already in heap order; heapify just makes that explicit.
            pq = [(A[i] + B[0], i, 0) for i in range(min(k, len(A)))]
            heapify(pq)

            res = []
            while len(res) < k and pq:
                total, i, j = heappop(pq)
                res.append(total)
                # B is sorted, so column j + 1 holds the smallest unseen sum
                # involving A[i].
                if j + 1 < len(B):
                    heappush(pq, (A[i] + B[j + 1], i, j + 1))
            return res

        base = mat[0]
        for row in mat[1:]:
            base = k_small_pairs(base, row, k)
        return base[k - 1]