class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Combines union by size with path compression, so ``find`` and ``union``
    run in near-constant amortized time.

    Attributes:
        root: ``root[i]`` is the parent of ``i``; ``i`` is a representative
            exactly when ``root[i] == i``.
        sz: ``sz[r]`` is the number of elements in the set whose
            representative is ``r`` (only meaningful for representatives).
        n_components: current number of disjoint sets.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets ``{0}, {1}, ..., {n-1}``."""
        self.root = list(range(n))
        self.sz = [1] * n
        self.n_components = n

    def find(self, a):
        """Return the representative of the set containing ``a``.

        Every node on the path from ``a`` to the representative is re-pointed
        directly at the representative (path compression).
        """
        root = a
        while root != self.root[root]:
            root = self.root[root]

        # Second pass: compress the path so later lookups are short.
        while a != root:
            nxt = self.root[a]
            self.root[a] = root
            a = nxt

        return root

    def union(self, a, b):
        """Merge the sets containing ``a`` and ``b``.

        Returns:
            The representative of the resulting set. If ``a`` and ``b`` were
            already in the same set, nothing changes and that set's
            representative is returned.
        """
        root1 = self.find(a)
        root2 = self.find(b)
        if root1 == root2:
            return root1

        # Union by size: hang the smaller tree under the larger one to keep
        # trees shallow. On a tie, root2 becomes the parent.
        if self.sz[root1] > self.sz[root2]:
            root1, root2 = root2, root1
        self.root[root1] = root2
        self.sz[root2] += self.sz[root1]

        self.n_components -= 1
        # Previously this path returned None, unlike the already-joined path.
        return root2

    def connected(self, a, b):
        """Return True if ``a`` and ``b`` belong to the same set."""
        return self.find(a) == self.find(b)

LeetCode practice problems for union-find:

2709. Greatest Common Divisor Traversal

1202. Smallest String With Swaps

399. Evaluate Division