A Step by Step Approximate Nearest Neighbor Example In Python From Scratch

Nearest neighbor search is a fundamental problem for vector models in machine learning. It involves finding the vectors closest to a given query vector and has many applications, including facial recognition, natural language processing, recommendation systems, image retrieval, and anomaly detection. However, for large datasets, exact nearest neighbor (k-NN) search can be prohibitively slow because every query must be compared against every stored vector. That's where approximate nearest neighbor (ANN) search comes in. ANN algorithms trade some accuracy for much faster query times, which makes them a useful tool for large-scale data analysis. In this blog post, we'll provide a gentle introduction to ANN in Python, covering the math behind a tree-based algorithm as well as a Python implementation from scratch. By the end of this post, you'll understand how approximate nearest neighbor search works, the key idea behind billion-scale fast vector similarity search in milliseconds!


Vlog

You can either continue following this tutorial or watch the following video. Both cover the math behind approximate nearest neighbor search and its implementation in Python from scratch.


πŸ™‹β€β™‚οΈ You may consider to enroll my top-rated machine learning course on Udemy

Decision Trees for Machine Learning

Popular Approximate Nearest Neighbour Libraries

There are several popular libraries available for approximate nearest neighbor search in Python, including Spotify’s Annoy and Facebook’s Faiss. These libraries implement state-of-the-art algorithms for approximate nearest neighbor search, making it easy for developers to perform efficient and scalable nearest neighbor search without requiring in-depth knowledge of the underlying processes. Annoy is a popular choice for tasks that involve indexing high-dimensional data, while Faiss is well-suited for applications that require extremely fast query times and support for large datasets. Both libraries offer a range of features, including support for different distance metrics, query types, and indexing methods. Developers can choose the library that best suits their needs based on their specific use case and dataset characteristics.

While libraries such as Annoy and Faiss are excellent choices for approximate nearest neighbor search in Python, implementing the algorithm from scratch will give you a deeper understanding of how it works. This approach can be useful for debugging, experimenting with different algorithms, and building custom solutions. It’s important to note that implementing ANN from scratch can be time-consuming, and may not always be the most efficient or practical solution. Additionally, it’s worth considering that the popular libraries have been developed and tested by teams of experts, and offer a range of advanced features and optimizations. However, for those who believe in the mantra “code wins arguments,” implementing ANN from scratch can be a valuable exercise in understanding the underlying concepts and algorithms.

I strongly recommend reading this blog post by Erik Bernhardsson, the author of the Annoy library, about how tree-based ANN works. It helped me a lot to understand the algorithm, and afterwards I decided to implement it from scratch.

Data set

In this study, we are going to use 2-dimensional vectors because they are easy to visualize. Real-world vectors, however, usually have many more dimensions. For instance, FaceNet produces 128-dimensional vectors and VGG-Face produces 2622-dimensional vectors.

import random
import numpy as np
import matplotlib.pyplot as plt

dimensions = 2
num_instances = 100

# generating the data set
vectors = []
for i in range(0, num_instances):
    x = round(100 * random.random(), 2)
    y = round(100 * random.random(), 2)
    
    vectors.append(np.array([x, y]))

# visualize the data set
fig = plt.figure(figsize=(10, 10))
for x, y in vectors:
    plt.scatter(x, y, c = 'black', marker = 'x')
plt.show()

This generates 100 two-dimensional vectors with coordinates between 0 and 100.
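If you want to try the same pipeline in higher dimensions, the loop above can be replaced by a vectorized sketch. This is an assumption-laden alternative, not the original code: the helper name `generate_vectors` is hypothetical, and it uses NumPy's `default_rng` Generator instead of the `random` module so that runs are reproducible with a seed.

```python
import numpy as np

def generate_vectors(num_instances=100, dimensions=2, seed=None):
    """Return a (num_instances, dimensions) array of points drawn uniformly
    from [0, 100) and rounded to 2 decimals, like the loop above."""
    rng = np.random.default_rng(seed)
    return np.round(rng.uniform(0, 100, size=(num_instances, dimensions)), 2)

points_2d = generate_vectors(100, 2, seed=42)      # same shape as the tutorial's data
points_128d = generate_vectors(100, 128, seed=42)  # FaceNet-sized vectors

print(points_2d.shape, points_128d.shape)  # (100, 2) (100, 128)
```

Iterating over the rows of these arrays yields one vector per row, just like iterating over the `vectors` list.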

Dataset

Data Structure

We are going to split the space into two halves using two randomly selected vectors. The line separating these two vectors (the set of points equidistant from both) is stored in a tree node as a decision rule. Then, we distribute all vectors according to whether they lie on the left or the right of that line. Thereafter, we split each resulting subset of vectors recursively in the same way.

If the vectors are 2-dimensional, we can split the space with a line. If they are n-dimensional, we split it with a hyperplane, the (n-1)-dimensional generalization of a line.





So, we can use the following class to construct our tree. Each node of the tree is an instance of the Node class, and a child node is attached to its parent's left or right attribute. hyperplane stores the equation of the hyperplane splitting the space (it stays None for leaves), value stores the vectors that reach this node, id stores their indices in the original dataset, and instances stores the number of those vectors.

class Node:
    def __init__(self, hyperplane = None, value = None, id = None, instances = 0):
        self.left = None
        self.right = None
        self.hyperplane = hyperplane
        self.value = value
        self.id = id
        self.instances = instances

Hyperplane

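The hyperplane attribute of the Node class holds n + 1 numbers for n-dimensional vectors: the first n are the coefficients of the hyperplane's normal vector, and the last one is the constant term (find_hyperplane below returns it as a NumPy array). Suppose that its value is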

[1, 2, 3, 4]

Then the equation of the hyperplane is

x + 2y + 3z = 4
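To make this storage convention concrete, here is a small sketch that evaluates the hyperplane [1, 2, 3, 4] at a few 3-dimensional points. The helper name `evaluate` is hypothetical. A result of zero means the point lies on the hyperplane, and the sign tells you which side of it the point is on.

```python
import numpy as np

hyperplane = np.array([1, 2, 3, 4])  # encodes x + 2y + 3z = 4

def evaluate(point, hyperplane):
    """Return w . point - b, where w = hyperplane[:-1] and b = hyperplane[-1]."""
    return np.dot(hyperplane[:-1], point) - hyperplane[-1]

print(evaluate([4, 0, 0], hyperplane))  # 0  -> on the hyperplane (4 = 4)
print(evaluate([1, 1, 1], hyperplane))  # 2  -> positive side (1 + 2 + 3 - 4)
print(evaluate([0, 0, 0], hyperplane))  # -4 -> negative side
```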

Finding the hyperplane

To determine the hyperplane equidistant from two given n-dimensional vectors, first calculate their midpoint by averaging their corresponding components. Next, find the direction vector pointing from the first vector to the second and normalize it to obtain a unit vector. This unit vector serves as the normal vector of the hyperplane, because the hyperplane we want is perpendicular to the segment joining the two vectors. With the midpoint and the normal vector in hand, compute the constant term as the dot product of the midpoint and the normal vector; since the normal has unit length, this is the signed distance of the hyperplane from the origin. Finally, form the hyperplane by appending this constant term to the components of the normal vector. Because the resulting hyperplane passes through the midpoint and is perpendicular to the segment between the two vectors, every point on it is equidistant from both.

def find_hyperplane(v1, v2):
    '''
    finds the hyperplane equidistant from
    two given n-dimensional vectors v1 and v2
    '''
    # find the midpoint of two vectors
    midpoint = (v1 + v2) / 2
    
    # find the direction vector from v1 to v2
    direction_vector = v2 - v1
    
    # find the unit vector of the direction vector
    unit_vector = direction_vector / np.linalg.norm(direction_vector)
    
    # define a normal vector to the hyperplane
    normal_vector = unit_vector
    
    # constant term: projection of the midpoint onto the unit normal,
    # i.e. the signed distance of the hyperplane from the origin
    distance = np.dot(midpoint, normal_vector)
    
    # define the equation of the hyperplane
    hyperplane = np.concatenate((normal_vector, [distance]))
    
    return hyperplane
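A quick sanity check helps confirm the construction. The sketch below repeats find_hyperplane in compact form so that it runs on its own. For v1 = (0, 0) and v2 = (2, 0), the bisector should be the vertical line x = 1. In 5 dimensions, random points projected onto the hyperplane should be equally far from both vectors.

```python
import numpy as np

def find_hyperplane(v1, v2):
    # same construction as above: unit normal from v1 to v2 through the midpoint
    midpoint = (v1 + v2) / 2
    normal_vector = (v2 - v1) / np.linalg.norm(v2 - v1)
    return np.concatenate((normal_vector, [np.dot(midpoint, normal_vector)]))

v1 = np.array([0.0, 0.0])
v2 = np.array([2.0, 0.0])
print(find_hyperplane(v1, v2))  # [1. 0. 1.]  ->  1*x + 0*y = 1, i.e. x = 1

# equidistance check in 5 dimensions: project random points onto the
# hyperplane and compare their distances to both vectors
rng = np.random.default_rng(0)
a, b = rng.normal(size=5), rng.normal(size=5)
h = find_hyperplane(a, b)
normal, offset = h[:-1], h[-1]
for _ in range(3):
    p = rng.normal(size=5)
    p_on_plane = p - (np.dot(normal, p) - offset) * normal  # orthogonal projection
    print(np.isclose(np.linalg.norm(p_on_plane - a), np.linalg.norm(p_on_plane - b)))  # True
```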

Deciding whether a vector is on the left or right of a hyperplane

To determine whether a vector lies on the left or right side of a given hyperplane, we calculate the signed distance from the vector to the hyperplane: the dot product of the vector and the hyperplane's normal vector, minus the hyperplane's constant term. Because the normal vector has unit length, this value is the actual signed distance. If it is negative, the vector is considered to be on the left side of the hyperplane; if it is positive, the vector is on the right side. If it is exactly zero, the vector lies on the hyperplane, and by convention it is treated as being on the right side so that every vector is assigned consistently. Since the normal points from the first vector toward the second, the first vector always falls on the left and the second on the right.

def is_on_left(v, hyperplane):
    # calculate the signed distance from v to the hyperplane
    signed_distance = np.dot(hyperplane[:-1], v) - hyperplane[-1]
    
    if signed_distance < 0:
        return True
    else:
        return False
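Calling is_on_left once per vector is clear but slow in pure Python. As a sketch of an alternative (not part of the original implementation), NumPy can compute all signed distances with a single matrix-vector product. The helper name `split_mask` is hypothetical. The example also checks the property noted above: the first chosen vector lands on the left (its signed distance is minus half the distance between the two vectors) and the second lands on the right, so every split makes progress.

```python
import numpy as np

def find_hyperplane(v1, v2):
    # unit normal pointing from v1 to v2, anchored at their midpoint
    normal = (v2 - v1) / np.linalg.norm(v2 - v1)
    return np.concatenate((normal, [np.dot((v1 + v2) / 2, normal)]))

def split_mask(vectors, hyperplane):
    """Boolean mask: True where a row of `vectors` is strictly on the left
    (negative signed distance), matching is_on_left; ties go right."""
    signed_distances = vectors @ hyperplane[:-1] - hyperplane[-1]
    return signed_distances < 0

rng = np.random.default_rng(7)
vectors = rng.uniform(0, 100, size=(100, 2))
v1, v2 = vectors[3], vectors[42]
hyperplane = find_hyperplane(v1, v2)

left = split_mask(vectors, hyperplane)
print(left[3], left[42])           # True False: v1 goes left, v2 goes right
print(left.sum() + (~left).sum())  # 100: every vector lands on exactly one side
```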

Splitting the space recursively

The split_nodes function is the heart of the tree construction. Given a set of vectors, it selects two random points and finds the hyperplane equidistant from them. It then divides the vectors into left and right subsets according to their position relative to that hyperplane and creates a Node that stores the hyperplane. If a subset contains more vectors than a predefined threshold (subset_size, set to 5 in this experiment), split_nodes is called recursively on it to build the corresponding child; otherwise, the subset becomes a leaf node. The recursion therefore stops once every leaf holds at most 5 vectors.

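import random

subset_size = 5   # a node with at most this many vectors becomes a leaf
hyperplanes = []  # every splitting hyperplane, kept for the step-by-step plots

def split_nodes(vectors, ids = None):
    if ids is None:
        ids = [*range(0, len(vectors))]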
    
    # pick two random points
    point_1st_idx = 0; point_2nd_idx = 0
    
    while point_1st_idx == point_2nd_idx:
        point_1st_idx = random.randint(0, len(vectors) - 1)
        point_2nd_idx = random.randint(0, len(vectors) - 1)
    
    v1 = vectors[point_1st_idx]
    v2 = vectors[point_2nd_idx]
    
    # find the hyperplane equidistant from those two vectors
    hyperplane = find_hyperplane(v1, v2)
    hyperplanes.append(hyperplane)
    
    # split vectors into left and right nodes
    left_nodes = []
    right_nodes = []
    
    left_ids = []
    right_ids = []
    
    for idx, vector in enumerate(vectors):
        is_left_node = is_on_left(v=vector, hyperplane=hyperplane)
        
        if is_left_node is True:
            left_nodes.append(vector)
            left_ids.append(ids[idx])
        else:
            right_nodes.append(vector)
            right_ids.append(ids[idx])

    assert len(left_nodes) + len(right_nodes) == len(vectors)
    
    current_node = Node(
        hyperplane=hyperplane,
        value=vectors,
        id=ids,
        instances=len(vectors)
    )
    
    if len(left_nodes) > subset_size:
        current_node.left = split_nodes(
            vectors=left_nodes,
            ids=left_ids
        )
    else:
        current_node.left = Node(
            value=left_nodes,
            id=left_ids,
            instances=len(left_nodes)
        )
    
    if len(right_nodes) > subset_size:
        current_node.right = split_nodes(
            vectors=right_nodes,
            ids=right_ids
        )
    else:
        current_node.right = Node(
            value=right_nodes,
            id=right_ids,
            instances=len(right_nodes)
        )
    
    return current_node

Once the recursive split_nodes function is ready, we can construct our tree.

tree = split_nodes(vectors)

Search

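Now we can use the built tree to find the approximate nearest neighbours of a query vector. Starting from the root, we follow the left or right child depending on which side of each hyperplane the query falls, and return the vectors stored in the node where we stop.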





# search nearest neighbours to this vector
v = [50, 50]

# find k nearest neighbours
k = 3

# descend from the root while the next child still holds at least k vectors
node = tree
while node.hyperplane is not None:
    if is_on_left(v, node.hyperplane) is True:
        child = node.left
    else:
        child = node.right

    if child.instances < k:
        break  # stay at the current node so that at least k candidates remain
    node = child

print(f'nearest neighbor vectors: {node.value}')
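The loop above returns every vector in the node where the descent stops. A common refinement, sketched here under the assumption that you want exactly k results, is to rank those candidates by their true Euclidean distance and keep the best k. The sketch is self-contained: `build_tree` and `query` are hypothetical compact stand-ins for split_nodes and the search loop, using dictionaries instead of the Node class. It also compares the result with exact brute-force search, so you can see when the approximation misses a true neighbor.

```python
import random
import numpy as np

def build_tree(points, ids, leaf_size, rng):
    """Compact stand-in for split_nodes: every node keeps the dataset indices
    that reach it; internal nodes also keep a (normal, offset) hyperplane."""
    node = {"ids": ids}
    if len(ids) <= leaf_size:
        return node  # leaf
    i, j = rng.sample(ids, 2)  # two distinct random vectors from this subset
    normal = (points[j] - points[i]) / np.linalg.norm(points[j] - points[i])
    offset = np.dot((points[i] + points[j]) / 2, normal)
    side = points[ids] @ normal - offset  # signed distances, negative = left
    node.update(
        normal=normal,
        offset=offset,
        left=build_tree(points, [k for k, s in zip(ids, side) if s < 0], leaf_size, rng),
        right=build_tree(points, [k for k, s in zip(ids, side) if s >= 0], leaf_size, rng),
    )
    return node

def query(tree, points, v, k):
    """Descend while the next child still holds at least k vectors, then
    rank the candidates in the reached node by exact Euclidean distance."""
    node = tree
    while "normal" in node:
        on_left = np.dot(node["normal"], v) - node["offset"] < 0
        child = node["left"] if on_left else node["right"]
        if len(child["ids"]) < k:
            break
        node = child
    return sorted(node["ids"], key=lambda idx: np.linalg.norm(points[idx] - v))[:k]

rng = random.Random(0)
points = np.array([[round(100 * rng.random(), 2), round(100 * rng.random(), 2)]
                   for _ in range(100)])
tree = build_tree(points, list(range(len(points))), leaf_size=5, rng=rng)

v, k = np.array([50.0, 50.0]), 3
approx = query(tree, points, v, k)
exact = [int(i) for i in np.argsort(np.linalg.norm(points - v, axis=1))[:k]]
print("approximate:", approx)
print("exact:      ", exact)
print("recall:", len(set(approx) & set(exact)) / k)
```

Ranking costs only a handful of distance computations, because the reached node holds few vectors.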

Visualization

Once the tree is built, we can find the approximate nearest neighbors of a given vector. The figure below shows the 5 nearest neighbors found for a target vector: the target vector is shown in red and its nearest neighbors in blue.

Nearest Neighbor Results

Training Steps

The tree-building steps make for interesting visualizations. In each figure, the two randomly selected vectors are shown with red x markers. The hyperplane equidistant from them is drawn as a red line for the current iteration, whereas the hyperplanes of previous iterations are drawn as black lines. The equation of the current iteration's hyperplane (red line) is shown at the top of the graph. In this run, the whole tree was constructed in 28 steps. Keep in mind that the tree should be built offline, once, before it is used to answer queries.
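Drawing a 2-D hyperplane requires two points on its line. One way to get them is to solve w0*x + w1*y = b for y at the left and right edges of the plot, falling back to a vertical line when w1 is zero. The sketch below assumes the plot spans 0 to 100 on both axes (the data range used in this tutorial), and the helper name `line_endpoints` is hypothetical.

```python
import numpy as np

def line_endpoints(hyperplane, low=0.0, high=100.0):
    """Two points on the 2-D line w0*x + w1*y = b, spanning [low, high]
    along the free coordinate, suitable for a line plot."""
    w0, w1, b = hyperplane
    if abs(w1) > 1e-12:
        # solve for y at the left and right edges of the plot
        xs = np.array([low, high])
        ys = (b - w0 * xs) / w1
    else:
        # vertical line: x is fixed at b / w0
        ys = np.array([low, high])
        xs = np.full(2, b / w0)
    return xs, ys

xs, ys = line_endpoints(np.array([1.0, 0.0, 1.0]))  # the line x = 1
print(xs, ys)  # x stays at 1 while y spans 0..100
```

With matplotlib, plt.plot(xs, ys, color='red') would draw the current split and color='black' the earlier ones. Setting plt.xlim(0, 100) and plt.ylim(0, 100) keeps steep lines inside the frame.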

Step 1
Step 2
Step 3
Step 28

Random Forest

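By the nature of the algorithm, we pick 2 random points at each split, so a single tree can separate a query from one of its true nearest neighbors just because of an unlucky split. To make the results more robust against lucky or unlucky random point selections, you may build many trees, in the spirit of the random forest algorithm, and combine the candidates they return. Annoy follows this approach by building a forest of trees.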
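Here is a minimal sketch of that idea, assuming the simplest combination rule: take the union of the leaves reached in every tree, then rank the candidates by exact distance. This is not Annoy's actual search procedure. The names `build_tree`, `leaf_for` and `forest_query` are hypothetical. Each leaf holds at most 5 vectors, so a single tree may even return fewer than k results. Adding trees can only grow the candidate set, so recall never decreases as trees are added.

```python
import random
import numpy as np

def build_tree(points, ids, leaf_size, rng):
    # compact random-split tree: random pair -> perpendicular bisector
    node = {"ids": ids}
    if len(ids) <= leaf_size:
        return node
    i, j = rng.sample(ids, 2)
    normal = (points[j] - points[i]) / np.linalg.norm(points[j] - points[i])
    offset = np.dot((points[i] + points[j]) / 2, normal)
    side = points[ids] @ normal - offset
    node["normal"], node["offset"] = normal, offset
    node["left"] = build_tree(points, [k for k, s in zip(ids, side) if s < 0], leaf_size, rng)
    node["right"] = build_tree(points, [k for k, s in zip(ids, side) if s >= 0], leaf_size, rng)
    return node

def leaf_for(tree, v):
    """Follow the hyperplanes from the root down to the leaf containing v."""
    node = tree
    while "normal" in node:
        node = node["left"] if np.dot(node["normal"], v) - node["offset"] < 0 else node["right"]
    return node["ids"]

def forest_query(forest, points, v, k):
    """Union the leaves reached in every tree, then rank by exact distance."""
    candidates = set()
    for tree in forest:
        candidates.update(leaf_for(tree, v))
    return sorted(candidates, key=lambda idx: np.linalg.norm(points[idx] - v))[:k]

rng = random.Random(1)
points = np.array([[100 * rng.random(), 100 * rng.random()] for _ in range(1000)])
forest = [build_tree(points, list(range(len(points))), leaf_size=5, rng=rng) for _ in range(10)]

v, k = np.array([50.0, 50.0]), 5
exact = set(int(i) for i in np.argsort(np.linalg.norm(points - v, axis=1))[:k])
for n_trees in (1, 3, 10):
    found = forest_query(forest[:n_trees], points, v, k)
    print(n_trees, "trees -> recall", len(exact & set(found)) / k)
```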

Time complexity

Suppose that we have n vectors in our dataset. To find the exact k nearest neighbors of a given vector, we first need to compute the distance from the given vector to every vector in the dataset. That requires n distance calculations, so this part is O(n). Then, we need to sort those n distances. Python's built-in sort uses Timsort, whose complexity is O(n log n). To sum up, exact search requires O(n) + O(n log n) = O(n log n) operations.

In our experiment, we have 100 instances. Big-O notation hides constant factors, so the following is only a back-of-the-envelope count: one operation per distance calculation plus about n log₂ n comparisons for sorting. With that, exact nearest neighbor search takes roughly 764 operations to find the k nearest neighbours.

100 + 100 × log₂(100) ≈ 100 + 100 × 6.64 ≈ 764

On the other hand, once the tree is built, you can reach the leaf holding the 5 nearest neighbor candidates in 4 steps, as the path below shows. That is roughly 190 times fewer operations (764 / 4 ≈ 191), even for such a small dataset. Of course, exact nearest neighbor search does not require building a tree. Approximate nearest neighbor search costs extra memory to store the tree and extra time to build it, and its results are not guaranteed to be exact, but most of the time it is worth it!

root: 100 instances
├── go to left: 47 instances
│   ├── go to left: 28 instances
│   │   ├── go to left: 11 instances
│   │   │   ├── go to left: 5 instances
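For comparison, the exact search whose cost is estimated above can be sketched in a few lines of NumPy. The helper names `exact_knn` and `exact_knn_partial` are hypothetical. The second version is an optional refinement that is not part of the original analysis: `np.argpartition` selects the k smallest distances in linear time on average, so only those k values need sorting. Both versions still compute the distance to every vector, which is exactly the cost the tree avoids.

```python
import numpy as np

def exact_knn(vectors, v, k):
    """Exact k-NN: n distance computations, then a full sort (O(n log n))."""
    distances = np.linalg.norm(vectors - v, axis=1)  # O(n)
    order = np.argsort(distances)                    # O(n log n)
    return order[:k]

def exact_knn_partial(vectors, v, k):
    """Same result without a full sort: argpartition moves the k smallest
    distances to the front, then only those k are sorted."""
    distances = np.linalg.norm(vectors - v, axis=1)
    smallest = np.argpartition(distances, k - 1)[:k]
    return smallest[np.argsort(distances[smallest])]

rng = np.random.default_rng(0)
vectors = rng.uniform(0, 100, size=(100, 2))
v = np.array([50.0, 50.0])
print(exact_knn(vectors, v, 5))
print(exact_knn_partial(vectors, v, 5))  # same indices, in the same order
```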

Conclusion

In this blog post, we've covered the basics of approximate nearest neighbor search in Python, including the mathematical concepts behind a tree-based algorithm and a Python implementation from scratch. We've also briefly introduced two popular ANN libraries, Spotify's Annoy and Facebook's Faiss, and the use cases they suit. ANN algorithms are a powerful tool for large-scale data analysis, allowing us to trade some accuracy for faster query times. However, keep in mind that the quality of the approximation depends on a heuristic (here, random splits), so the returned neighbours will not always be the exact nearest neighbours. With the knowledge and tools presented in this blog post, you should be well on your way to performing efficient nearest neighbor search in your Python-based data analysis projects.

I pushed the source code of this study to GitHub. If you like this work, please star ⭐ its repo.





