[HL-29] Algorithm - Union Find
Context and problem
Union-find (also known as a disjoint-set data structure) is used to partition a set of elements into a number of disjoint subsets.
For example,
# a set of elements
[1, 2, 3, 4, 5, 6, 7, 8]
# some elements are connected
(1,2)
(2,5)
(2,6)
(8,6)
(3,4)
# partitioned into disjoint subsets
[[1,2,5,6,8], [3,4], [7]]
Solution
We can use union-find to solve this problem. Basically, we create a forest (a set of trees) in which each tree represents one disjoint set.
It has two operations: union and find.
find returns the root of the given element by following the chain of parent pointers until it reaches the root element.
union connects two elements by merging their trees if their roots are different.
- Initially, every element’s parent is itself.
- When two elements are connected, we first use find to look up their roots.
- If their roots are different, we merge the two trees into a single tree by making one root the parent of the other.
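The steps above can be sketched in Ruby as follows. This is a minimal illustration without any balancing, not a definitive implementation; the class name DisjointSet and the groups helper are assumptions introduced here. It uses the elements and connections from the example at the top.

```ruby
# Minimal union-find sketch (no balancing).
# DisjointSet and groups are hypothetical names used for illustration.
class DisjointSet
  def initialize(elements)
    # Initially, every element's parent is itself.
    @parent = {}
    elements.each { |e| @parent[e] = e }
  end

  # Follow the chain of parent pointers until reaching the root.
  def find(x)
    x = @parent[x] while @parent[x] != x
    x
  end

  # Merge the trees containing x and y if their roots are different.
  def union(x, y)
    root_x = find(x)
    root_y = find(y)
    @parent[root_x] = root_y if root_x != root_y
  end

  # Group all elements by their root to list the disjoint subsets.
  def groups
    @parent.keys.group_by { |e| find(e) }.values
  end
end

ds = DisjointSet.new((1..8).to_a)
[[1, 2], [2, 5], [2, 6], [8, 6], [3, 4]].each { |x, y| ds.union(x, y) }
p ds.groups.map(&:sort).sort
# prints [[1, 2, 5, 6, 8], [3, 4], [7]]
```

A hash is used for the parent pointers here so that the elements do not have to be consecutive integers; an array works just as well when they are.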
Issues and considerations
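- The height of a tree can grow to O(n) in the worst case.
- An unbalanced tree increases the time of find, because find walks the whole path from the element to the root.
- There are two ways to keep trees balanced: merging by height (rank) or by size, so that the shorter or smaller tree is attached under the root of the other.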
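As a rough sketch of merging by size, the code below always attaches the smaller tree under the root of the larger one. The names find_root and union_by_size, and the separate size array, are assumptions made for this illustration.

```ruby
# Sketch of union by size; all names here are hypothetical.
def find_root(parent, x)
  x = parent[x] while parent[x] != x
  x
end

# Returns true if two trees were merged, false if x and y
# were already in the same tree.
def union_by_size(parent, size, x, y)
  root_x = find_root(parent, x)
  root_y = find_root(parent, y)
  return false if root_x == root_y

  # Swap so that root_x is always the root of the larger tree.
  root_x, root_y = root_y, root_x if size[root_x] < size[root_y]
  parent[root_y] = root_x
  size[root_x] += size[root_y]
  true
end

n = 8
parent = (0...n).to_a      # every element starts as its own root
size = Array.new(n, 1)     # every tree starts with one element

# Connect the elements in a chain: (0,1), (1,2), ..., (6,7).
(0...n - 1).each { |i| union_by_size(parent, size, i, i + 1) }
p parent
# prints [0, 0, 0, 0, 0, 0, 0, 0]
```

With these chained connections, every element ends up pointing directly at root 0. Merging without balancing (always setting the first root's parent to the second root) would instead build a single chain of height n - 1 for the same input.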
When to use this algorithm
This algorithm is used for finding disjoint sets. Even if the problem is defined in 2D, we can flatten the data into a one-dimensional array (for example, mapping cell (i, j) of a grid with n columns to index i * n + j) and apply this algorithm.
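A small sketch of that 2D-to-1D mapping, assuming a grid with n columns. The helper names to_index and to_cell are hypothetical and exist only for this illustration.

```ruby
# Hypothetical helpers: map cell (i, j) of a grid with n columns
# to a flat array index, and map the index back to the cell.
def to_index(i, j, n)
  i * n + j
end

def to_cell(index, n)
  index.divmod(n) # [row, column]
end

n = 5
idx = to_index(2, 3, n)
p idx              # prints 13
p to_cell(idx, n)  # prints [2, 3]
```

Each cell gets a unique index between 0 and m * n - 1 for a grid with m rows, so the flat array can serve directly as the parent array of the union-find.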
Example
- LeetCode #200 (Number of Islands)
# ruby version
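# Returns the root of x by following the parent pointers in a
# until reaching an element that is its own parent.
def find(a, x)
  pts = x
  while a[pts] != pts
    pts = a[pts]
  end
  pts
end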
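# Parses the test string into a 2D grid of characters.
# The single-quoted '\n' is the literal two characters backslash and n,
# which matches the single-quoted test data at the end of this example.
def generate_input(test_data)
  test_data.split('\n').map { |item| item.chars }
end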
# @param {Character[][]} grid
# @return {Integer}
def num_islands(grid)
  return 0 if grid.nil? || grid.empty? || grid[0].empty?

  m = grid.length
  n = grid[0].length
  count = 0

  # a[k] is the parent of cell k, where k = i * n + j.
  # Each land cell starts as its own root; water cells get -1.
  a = []
  m.times do |i|
    n.times do |j|
      if grid[i][j] == '1'
        a.push(i * n + j)
        count += 1
      else
        a.push(-1)
      end
    end
  end

  moves = [[0, 1], [1, 0], [0, -1], [-1, 0]]
  m.times do |i|
    n.times do |j|
      next if grid[i][j] == '0'

      moves.each do |move|
        x = i + move[0]
        y = j + move[1]
        next unless x >= 0 && x < m && y >= 0 && y < n && grid[x][y] == '1'

        root_x = find(a, i * n + j)
        root_y = find(a, x * n + y)
        # Merging two different trees removes one island from the count.
        if root_x != root_y
          a[root_x] = root_y
          count -= 1
        end
      end
    end
  end
  count
end
# test case
test_data = '11110\n' \
'11010\n' \
'11000\n' \
'00000'
grid = generate_input(test_data)
puts num_islands(grid) == 1