class Solution:
    def kthSmallest(self, mat: List[List[int]], k: int) -> int:
        """Return the k-th smallest sum obtained by picking one value per row.

        Main idea: fold the rows together pairwise. Starting from the first
        row, repeatedly merge the running list of smallest sums with the next
        row, keeping only the ``k`` smallest pair sums each time. Nothing past
        the k-th smallest partial sum can contribute to the k-th smallest
        total, so truncating after every merge is safe.

        Assumes every row is sorted in non-decreasing order; the heap-based
        merge below only yields sums in ascending order under that ordering.

        Args:
            mat: Non-empty list of rows of integers.
            k: 1-based rank of the requested sum.

        Returns:
            The k-th smallest achievable sum.

        Raises:
            IndexError: If ``mat`` is empty, ``k < 1``, or ``k`` is larger
                than the number of sums that can be formed.

        Complexity: each of the ``m - 1`` merges performs at most ``k`` heap
        pops/pushes on a heap of at most ``k`` entries, i.e. O(m * k * log k)
        time and O(k) extra space.
        """
        if k < 1:
            # Without this guard a single-row matrix would silently answer
            # with a negative index (e.g. k == 0 returned the row's last item).
            raise IndexError("k must be at least 1")

        def k_smallest_pair_sums(left, right, limit):
            """Return up to ``limit`` smallest ``left[i] + right[j]`` sums, ascending."""
            frontier = []  # heap entries: (pair_sum, index into left, index into right)
            # Seed with every candidate from `left` paired with right[0]; only
            # the first `limit` entries of `left` can ever be needed.
            for i in range(min(limit, len(left))):
                heappush(frontier, (left[i] + right[0], i, 0))

            sums = []
            last_j = len(right) - 1
            while frontier and len(sums) < limit:
                pair_sum, i, j = heappop(frontier)
                sums.append(pair_sum)
                # Advance along `right` for the same `left` element; each
                # (i, j) pair is therefore pushed at most once.
                if j < last_j:
                    heappush(frontier, (left[i] + right[j + 1], i, j + 1))
            return sums

        best = mat[0]
        for row in mat[1:]:
            best = k_smallest_pair_sums(best, row, k)

        if k > len(best):
            raise IndexError("k exceeds the number of possible sums")
        return best[k - 1]