"""
Algorithms around set operations
"""
import numpy as np
import tfields
[docs]class UnionFind(object):
"""
Source:
http://code.activestate.com/recipes/215912-union-find-data-structure/
This algorithm and data structure are primarily used for Kruskal's Minimum
Spanning Tree algorithm for graphs, but other uses have been found.
The Union Find data structure is not a universal set implementation, but can
tell you if two objects are in the same set, in different sets, or you can
combine two sets.
ufset.find(obja) == ufset.find(objb)
ufset.find(obja) != ufset.find(objb)
ufset.union(obja, objb)
"""
def __init__(self):
"""
Create an empty union find data structure.
"""
self.num_weights = {}
self.parent_pointers = {}
self.num_to_objects = {}
self.objects_to_num = {}
self.__repr__ = self.__str__
[docs] def insert_objects(self, objects):
"""
Insert a sequence of objects into the structure. All must be Python hashable.
"""
for obj in objects:
self.find(obj)
[docs] def find(self, obj):
"""
Find the root of the set that an object 'obj' is in.
If the object was not known, will make it known, and it becomes its own set.
Object must be Python hashable.'''
"""
if obj not in self.objects_to_num:
obj_num = len(self.objects_to_num)
self.num_weights[obj_num] = 1
self.objects_to_num[obj] = obj_num
self.num_to_objects[obj_num] = obj
self.parent_pointers[obj_num] = obj_num
return obj
stk = [self.objects_to_num[obj]]
par = self.parent_pointers[stk[-1]]
while par != stk[-1]:
stk.append(par)
par = self.parent_pointers[par]
for i in stk:
self.parent_pointers[i] = par
return self.num_to_objects[par]
[docs] def union(self, object1, object2):
"""
Combine the sets that contain the two objects given.
Both objects must be Python hashable.
If either or both objects are unknown, will make them known, and combine them.
"""
o1p = self.find(object1)
o2p = self.find(object2)
if o1p != o2p:
on1 = self.objects_to_num[o1p]
on2 = self.objects_to_num[o2p]
w1 = self.num_weights[on1]
w2 = self.num_weights[on2]
if w1 < w2:
o1p, o2p, on1, on2, w1, w2 = o2p, o1p, on2, on1, w2, w1
self.num_weights[on1] = w1 + w2
del self.num_weights[on2]
self.parent_pointers[on2] = on1
def __str__(self):
"""
Included for testing purposes only.
All information needed from the union find data structure can be attained
using find.
"""
sets = {}
for i in range(len(self.objects_to_num)):
sets[i] = []
for i in self.objects_to_num:
sets[self.objects_to_num[self.find(i)]].append(i)
out = []
for i in sets.itervalues():
if i:
out.append(repr(i))
return ", ".join(out)
def __call__(self, iterator):
"""
Build unions for whole iterators of any size
"""
self.insert_objects(tfields.lib.util.flatten(iterator))
i = 0
for item in iterator:
for i1, i2 in tfields.lib.util.pairwise(item):
self.union(i1, i2)
i += 1
[docs] def groups(self, iterator):
"""
Return full groups from iterator
"""
groups = {}
keys = []
for item in iterator:
key = self.find(item[0])
if key not in keys:
keys.append(key)
if key not in groups:
groups[key] = []
groups[key].append(item)
return [groups[k] for k in keys]
[docs] def group_indices(self, iterator):
"""
Return full groups from iterator
"""
group_indices = {}
keys = []
for i, item in enumerate(iterator):
key = self.find(item[0])
if key not in keys:
keys.append(key)
if key not in group_indices:
group_indices[key] = []
group_indices[key].append(i)
return [group_indices[k] for k in keys]
[docs]def disjoint_groups(iterator):
"""
Disjoint groups implementation
Examples:
>>> import tfields
>>> tfields.lib.sets.disjoint_groups([[0, 0, 0, 'A'], [1, 2, 3], [3, 0],
... [4, 4, 4], [5, 4], ['X', 0.42]])
[[[0, 0, 0, 'A'], [1, 2, 3], [3, 0]], [[4, 4, 4], [5, 4]], [['X', 0.42]]]
>>> tfields.lib.sets.disjoint_groups([[0], [1], [2], [3], [0, 1], [1, 2], [3, 0]])
[[[0], [1], [2], [3], [0, 1], [1, 2], [3, 0]]]
Returns:
list: iterator items grouped in disjoint sets
"""
uf = UnionFind()
uf(iterator)
return uf.groups(iterator)
[docs]def disjoint_group_indices(iterator):
"""
Examples:
>>> import tfields
>>> tfields.lib.sets.disjoint_group_indices([[0, 0, 0, 'A'], [1, 2, 3],
... [3, 0], [4, 4, 4], [5, 4], ['X', 0.42]])
[[0, 1, 2], [3, 4], [5]]
>>> tfields.lib.sets.disjoint_group_indices([[0], [1], [2], [3], [0, 1], [1, 2], [3, 0]])
[[0, 1, 2, 3, 4, 5, 6]]
Returns:
list: indices of iterator items grouped in disjoint sets
"""
uf = UnionFind()
if isinstance(iterator, np.ndarray):
iterator = iterator.tolist()
uf(iterator)
return uf.group_indices(iterator)
[docs]def remap(
arr: np.ndarray, keys: np.ndarray, values: np.ndarray, inplace=False
) -> np.ndarray:
"""
Given an input array, remap its entries corresponding to 'keys' to 'values'
Args:
input: array to remap
keys: values to be replaced
values : values to replace 'keys' with
Returns:
output:
like 'input', but with elements remapped according to the mapping
defined by 'keys' and 'values'
"""
assert arr.dtype == int
assert arr.min() >= 0
"""
Assuming the values are between 0 and some maximum integer,
one could implement a fast replace by using the numpy-array as int->int dict, like below
"""
if not inplace:
arr = arr.copy()
mp = np.arange(0, arr.max() + 1)
mp[keys] = values
arr.ravel()[:] = mp[arr.ravel()]
return arr
if __name__ == "__main__":
import doctest
doctest.testmod()