# Heap Solution (note: the code below is actually a binary search over the value range, not a heap)

class Solution:
    """Find the k-th smallest value of a matrix by binary search on values."""

    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """Return the k-th smallest element of ``matrix`` (1-indexed).

        Assumes every row and every column is sorted in ascending order,
        so ``matrix[0][0]`` is the minimum and ``matrix[-1][-1]`` the
        maximum. Duplicates are counted individually.

        The search runs over the value range ``[min, max]`` rather than
        over indices: each iteration halves the range and performs one
        staircase walk of at most ``rows + cols`` steps. No extra storage
        proportional to the matrix is allocated.
        """
        lo = matrix[0][0]
        hi = matrix[-1][-1]
        while lo < hi:
            mid = lo + (hi - lo) // 2
            cnt, candi = self.cnt_less_equal(matrix, mid)
            if cnt == k:
                # Exactly k elements are <= mid, so the largest of them
                # is the k-th smallest overall.
                return candi
            if cnt < k:
                # Fewer than k elements are <= mid: the answer is larger.
                lo = mid + 1
            else:
                # More than k elements are <= mid: the answer is <= mid.
                hi = mid
        return lo

    def cnt_less_equal(self, matrix, val):
        """Count the elements ``<= val`` and find the largest of them.

        Walks a staircase from the top-right corner: whenever the current
        cell is ``<= val`` the whole row prefix up to that cell qualifies,
        otherwise the column is discarded.

        Returns a ``(count, max_val)`` tuple. ``max_val`` is the largest
        element ``<= val``; when no element qualifies (``count == 0``) it
        is the legacy sentinel ``-1``, so check ``count`` before trusting
        it.
        """
        rows = len(matrix)
        cols = len(matrix[0]) if matrix else 0
        cnt = 0
        # None (not -1) as the "nothing seen yet" marker so that negative
        # matrix values are tracked correctly.
        max_val = None
        i, j = 0, cols - 1
        while i < rows and j >= 0:
            cell = matrix[i][j]
            if cell <= val:
                if max_val is None or cell > max_val:
                    max_val = cell
                cnt += j + 1
                i += 1
            else:
                j -= 1
        return (cnt, -1 if max_val is None else max_val)

# Binary Search Solution

class Solution:
    """K-th smallest element of a sorted matrix via value-range bisection."""

    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """Return the k-th smallest element of ``matrix`` (1-indexed).

        Assumes rows and columns are each sorted ascending, which makes
        the top-left cell the minimum and the bottom-right cell the
        maximum. Equal values count as separate elements.

        Time: one staircase walk (at most ``rows + cols`` steps) per
        bisection step, with O(log(max - min)) bisection steps.
        Space: O(1) beyond the input.
        """
        low = matrix[0][0]
        high = matrix[-1][-1]
        while low < high:
            mid = low + (high - low) // 2
            count, largest = self.cnt_less_equal(matrix, mid)
            if count == k:
                # The k elements <= mid are exactly the k smallest, so
                # the largest among them is the answer.
                return largest
            if count < k:
                low = mid + 1
            else:
                high = mid
        return low

    def cnt_less_equal(self, matrix, val):
        """Return ``(count, max_val)`` for the elements that are ``<= val``.

        ``count`` is how many elements are ``<= val``; ``max_val`` is the
        largest such element. If nothing qualifies, ``count`` is 0 and
        ``max_val`` is the legacy placeholder ``-1``.

        The scan starts at the top-right cell and moves only down or
        left, so it touches at most ``rows + cols`` cells.
        """
        row_count = len(matrix)
        col_count = len(matrix[0]) if matrix else 0
        count = 0
        # Start with "no value yet" instead of -1; otherwise matrices of
        # negative numbers would report -1 as their largest element.
        largest = None
        row, col = 0, col_count - 1
        while row < row_count and col >= 0:
            current = matrix[row][col]
            if current > val:
                col -= 1
                continue
            if largest is None or current > largest:
                largest = current
            # Every cell to the left in this row is also <= val.
            count += col + 1
            row += 1
        if largest is None:
            return (count, -1)
        return (count, largest)