Source code for cam.sgnmt.misc.trie

# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains ``SimpleTrie`` which is a generic trie
implementation based on strings of integers.
"""

from operator import itemgetter

class SimpleNode:
    """Helper class representing a node in a ``SimpleTrie``."""

    def __init__(self):
        """Creates an empty node without children."""
        self.edges = {}  # outgoing edges: symbol -> SimpleNode
        self.element = None  # element stored at this node (None: no element)


class SimpleTrie:
    """This is a very simple Trie implementation. It is simpler than
    the one in ``cam.sgnmt.predictors.grammar`` because it does not
    support non-terminals or removal. The only supported operations
    are ``add`` and ``get``, but those are implemented very
    efficiently. For many applications (e.g. the cache in the greedy
    heuristic) this is already enough.

    The implementation also supports keys in sparse representation,
    in which most of the elements in the sequence are zero (see
    ``add_sparse``, ``get_sparse``, and ``nearest_sparse``). In this
    case, the key is a list of tuples [(dim1,val1),...(dimN,valN)].
    Internally, we store them as sequence "dim1 val1 dim2 val2...".
    Note that we assume that the tuples are ordered by dimension!

    ``None`` marks "no element stored", so ``None`` itself cannot be
    stored as an element. Any other value (including falsy ones such
    as ``0`` or an empty string) can.

    The nearest-neighbour searches keep their scratch state
    (``best_dist``, ``best_element``, ``best_elements``, ``n``) on the
    instance, so a single Trie must not run two searches concurrently.
    """

    def __init__(self):
        """Creates an empty Trie data structure."""
        self.root = SimpleNode()

    def _get_node(self, seq):
        """Get the ``SimpleNode`` for the given sequence ``seq``. If
        the path for ``seq`` does not exist yet in the Trie, add it
        and return a reference to the newly created node.
        """
        cur_node = self.root
        for token_id in seq:
            children = cur_node.edges
            if token_id not in children:
                children[token_id] = SimpleNode()
            cur_node = children[token_id]
        return cur_node

    def _find_node(self, seq):
        """Get the ``SimpleNode`` for ``seq`` without modifying the
        Trie. Returns ``None`` if the path for ``seq`` does not exist.
        """
        cur_node = self.root
        for token_id in seq:
            cur_node = cur_node.edges.get(token_id)
            if cur_node is None:
                return None
        return cur_node

    def add(self, seq, element):
        """Add an element to the Trie for the key ``seq``. If ``seq``
        already exists, override.

        Args:
            seq (list): Key
            element (object): The object to store for key ``seq``
        """
        self._get_node(seq).element = element

    def get(self, seq):
        """Retrieve the element for a key ``seq``. This is a pure
        lookup: no nodes are created for keys which are not stored.

        Args:
            seq (list): Query key

        Returns:
            object. The element which has been added along with
            ``seq`` or ``None`` if the key does not exist.
        """
        node = self._find_node(seq)
        if node is None:
            return None
        return node.element

    def get_prefix(self, seq):
        """Get the longest key in the Trie which is a prefix of
        ``seq``.

        Args:
            seq (list): Query sequence

        Returns:
            list. The longest key in the Trie which is a prefix of
            ``seq``, or an empty list if there is none.
        """
        cur_node = self.root
        prefix = []
        best_prefix = []
        for token_id in seq:
            cur_node = cur_node.edges.get(token_id)
            if cur_node is None:
                break
            prefix.append(token_id)
            # Compare against None so that falsy elements (0, "", ...)
            # still mark a stored key.
            if cur_node.element is not None:
                best_prefix = list(prefix)
        return best_prefix

    def _sparse2seq(self, key):
        """Transforms a key in sparse representation to a sequence
        which can be used as key in the Trie.
        """
        seq = []
        for (d, v) in key:
            seq.append(d)
            seq.append(v)
        return seq

    def add_sparse(self, key, element):
        """Adds an element with a key in sparse representation.

        Args:
            key (list): Sparse key (list of tuples)
            element (object): The object to store for key ``key``
        """
        self.add(self._sparse2seq(key), element)

    def get_sparse(self, key, element=None):
        """Retrieves an element with a key in sparse representation.

        Args:
            key (list): Query key in sparse format
            element (object): Unused. Only accepted so that callers
                              written against the old two-argument
                              signature keep working.

        Returns:
            object. The element which has been added along with
            ``key`` or ``None`` if the key does not exist.
        """
        return self.get(self._sparse2seq(key))

    def nearest_sparse(self, query):
        """This method returns the element in the Trie with the
        closest key to ``query`` in terms of Euclidean distance. The
        efficiency relies on sparseness: The more zeros in the vector,
        the more efficient. If the Trie contains an exact match, this
        method runs linear in the length of the query (i.e.
        independent of number of elements in the Trie).

        Args:
            query (list): Query key in sparse format

        Returns:
            Tuple. (object,dist) pair with the nearest element to
            ``query`` in terms of L2 norm and the squared L2 distance.
            ``(None, inf)`` if the Trie does not contain any element.
        """
        self.best_dist = float("inf")
        self.best_element = None
        self._register_best_element = self._register_best_element_single
        self._nearest_sparse_recursive(self._sparse2seq(query),
                                       self.root,
                                       0.0)
        return self.best_element, self.best_dist

    def n_nearest_sparse(self, query, n=1):
        """This method returns the n elements in the Trie with the
        closest keys to ``query`` in terms of Euclidean distance. The
        efficiency relies on sparseness: The more zeros in the vector,
        the more efficient.

        Args:
            query (list): Query key in sparse format
            n (int): Number of elements to retrieve

        Returns:
            List. List of (object,dist) pairs with the nearest
            elements to ``query`` in terms of L2 norm and the squared
            L2 distance, sorted by distance. If the Trie holds fewer
            than ``n`` elements, the list is shorter than ``n`` and
            ends with a ``(None, inf)`` placeholder entry.
        """
        if n <= 1:
            return [self.nearest_sparse(query)]
        self.best_dist = float("inf")
        self.best_elements = [(None, self.best_dist)]  # guardian element
        self.n = n
        self._register_best_element = self._register_best_element_multi
        self._nearest_sparse_recursive(self._sparse2seq(query),
                                       self.root,
                                       0.0)
        return self.best_elements

    def _register_best_element_single(self, dist, el):
        """Registers ``el`` as the new single best match."""
        self.best_dist = dist
        self.best_element = el

    def _register_best_element_multi(self, dist, el):
        """Inserts ``el`` into the sorted list of the ``n`` best
        matches, dropping the current worst entry if the list is full.
        ``best_dist`` is set to the distance of the worst kept entry.
        """
        self.best_elements = self.best_elements[:self.n-1] + [(el, dist)]
        self.best_elements.sort(key=itemgetter(1))
        self.best_dist = self.best_elements[-1][1]

    def _nearest_sparse_recursive(self, seq, root, dist):
        """Searches the subtree under ``root`` for elements close to
        the remaining query ``seq`` ("dim1 val1 dim2 val2..."), given
        that ``dist`` has already been accumulated on the way to
        ``root``.
        """
        if dist > self.best_dist:
            return
        if not seq:
            self._dfs_for_nearest(root, dist)
            return
        if root.element is not None:
            # The key ends here: all remaining query dimensions count
            # with their full value.
            add_dist = sum(val ** 2 for val in seq[1::2])
            if dist + add_dist < self.best_dist:
                self._register_best_element(dist + add_dist, root.element)
        dim = seq[0]
        seq_len = len(seq)
        # Explore close matches first
        children = sorted(root.edges, key=lambda el: (el - dim) ** 2)
        for child_dim in children:
            child_node = root.edges[child_dim]
            next_dist = dist
            pos = 0
            # Query dimensions smaller than child_dim are zero in all
            # keys below child_node.
            while pos < seq_len and child_dim > seq[pos]:
                next_dist += seq[pos + 1] ** 2
                pos += 2
            if pos >= seq_len:
                # Query exhausted: the rest of each key counts fully.
                for c, node in child_node.edges.items():
                    self._dfs_for_nearest(node, next_dist + c * c)
                continue
            if child_dim == seq[pos]:  # Exact match :)
                c_discount = seq[pos + 1]
                next_seq = seq[pos + 2:]
            else:
                c_discount = 0.0
                next_seq = seq[pos:]
            for c, node in child_node.edges.items():
                self._nearest_sparse_recursive(
                    next_seq, node, next_dist + (c - c_discount) ** 2)

    def _dfs_for_nearest(self, root, dist):
        """Scans the subtree under ``root`` for nearest elements.
        ``dist`` is the distance which has already been accumulated.
        """
        if dist > self.best_dist:
            return
        if root.element is not None:
            self._register_best_element(dist, root.element)
            # Descendants are at least ``dist`` away. Only keep going
            # if such an element could still be registered, which is
            # the case while an n-best list has worse entries.
            if dist >= self.best_dist:
                return
        for child in root.edges.values():
            for c, next_child in child.edges.items():
                self._dfs_for_nearest(next_child, dist + c * c)