# Source code for cam.sgnmt.predictors.automata

# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module encapsulates the predictor interface to OpenFST. This
module depends on OpenFST. To enable Python support in OpenFST, use a 
recent version (>=1.5.4) and compile with ``--enable_python``. 
Further information can be found here:

http://www.openfst.org/twiki/bin/view/FST/PythonExtension 

This file includes the fst, nfst, and rtn predictors.

Note: If we use arc weights in FSTs, we multiply them by -1 as 
everything in SGNMT is logprob, not -logprob as in FSTs log 
or tropical semirings. You can disable this behavior with --fst_to_log

Note2: The FSTs and RTNs are assumed to have both <S> and </S>. This 
has compatibility reasons, as lattices generated by HiFST have these
symbols.
"""

import glob
import logging
import os
import sys

from cam.sgnmt import utils
from cam.sgnmt.predictors.core import Predictor
from cam.sgnmt.utils import w2f, load_fst

try:
    import pywrapfst as fst
except ImportError:
    try:
        import openfst_python as fst
    except ImportError:
        pass # Deal with it in decode.py


EPS_ID = 0
"""OpenFST's reserved ID for epsilon arcs. """


class FstPredictor(Predictor):
    """This predictor can read determinized translation lattices. The
    predictor state consists of the current node. This is unique as the
    lattices are determinized.
    """

    def __init__(self, fst_path, use_weights, normalize_scores,
                 skip_bos_weight=True, to_log=True):
        """Creates a new fst predictor.

        Args:
            fst_path (string): Path to the FST file
            use_weights (bool): If false, replace all arc weights with
                                0 (=log 1).
            normalize_scores (bool): If true, we normalize the weights
                                     on all outgoing arcs such that
                                     they sum up to 1
            skip_bos_weight (bool): Add the score at the <S> arc to the
                                    </S> arc if this is false. This
                                    results in scores consistent with
                                    OpenFST's replace operation, as <S>
                                    scores are normally ignored by
                                    SGNMT.
            to_log (bool): SGNMT uses normal log probs (scores) while
                           arc weights in FSTs normally have cost (i.e.
                           neg. log values) semantics. Therefore, if
                           true, we multiply arc weights by -1.
        """
        super(FstPredictor, self).__init__()
        self.fst_path = fst_path
        # Arc weights are costs (-log) in tropical/log semirings; flip
        # the sign to obtain SGNMT log probabilities if requested.
        self.weight_factor = -1.0 if to_log else 1.0
        self.use_weights = use_weights
        self.normalize_scores = normalize_scores
        self.cur_fst = None
        self.add_bos_to_eos_score = not skip_bos_weight
        # Invalid/inactive state is encoded as -1 (initial) or None
        # (dead end / missing lattice); valid state IDs include 0.
        self.cur_node = -1

    def _has_active_node(self):
        """Returns true iff the predictor has an active FST node.

        ``cur_node`` may be -1 (not initialized), None (dead end or
        lattice failed to load), or a valid OpenFST state ID. Note
        that 0 is a perfectly valid state ID, so truthiness tests on
        ``cur_node`` must not be used.
        """
        return self.cur_node is not None and self.cur_node >= 0

    def get_unk_probability(self, posterior):
        """Returns negative infinity if UNK is not in the lattice.
        Otherwise, return UNK score.

        Returns:
            float. Negative infinity
        """
        return utils.common_get(posterior, utils.UNK_ID, utils.NEG_INF)

    def predict_next(self):
        """Uses the outgoing arcs from the current node to build up the
        scores for the next word.

        Returns:
            dict. Set of words on outgoing arcs from the current node
            together with their scores, or an empty set if we currently
            have no active node or fst.
        """
        # Fixed: the original compared ``self.cur_node < 0`` which
        # raises TypeError on Python 3 when cur_node is None.
        if not self._has_active_node():
            return {}
        scores = {arc.olabel: self.weight_factor * w2f(arc.weight)
                  for arc in self.cur_fst.arcs(self.cur_node)}
        if utils.EOS_ID in scores and self.add_bos_to_eos_score:
            scores[utils.EOS_ID] += self.bos_score
        return self.finalize_posterior(
            scores, self.use_weights, self.normalize_scores)

    def initialize(self, src_sentence):
        """Loads the FST from the file system and consumes the start
        of sentence symbol.

        Args:
            src_sentence (list): Not used
        """
        self.cur_fst = load_fst(utils.get_path(self.fst_path,
                                               self.current_sen_id+1))
        self.cur_node = self.cur_fst.start() if self.cur_fst else None
        # consume() returns None when the <S> arc is missing or an UNK
        # arc was followed; fall back to a neutral score of 0.0.
        self.bos_score = self.consume(utils.GO_ID)
        if not self.bos_score:  # Override None
            self.bos_score = 0.0
        if self.cur_node is None:
            logging.warning(
                "The lattice for sentence %d does not contain any "
                "valid path. Please double-check that the lattice "
                "is not empty and that paths contain the begin-of-"
                "sentence symbol %d. If you are using a different "
                "begin-of-sentence symbol, double-check --indexing_"
                "scheme." % (self.current_sen_id+1, utils.GO_ID))

    def consume(self, word):
        """Updates the current node by following the arc labelled with
        ``word``. If there is no such arc, we set ``cur_node`` to None,
        indicating that the predictor is in an invalid state. In this
        case, all subsequent ``predict_next`` calls will return the
        empty set.

        Args:
            word (int): Word on an outgoing arc from the current node

        Returns:
            float. Weight on the traversed arc, or None if no arc with
            label ``word`` was found (an UNK arc is followed instead if
            one exists, but its weight is not returned).
        """
        if not self._has_active_node():
            return
        from_state = self.cur_node
        self.cur_node = None
        unk_arc = None
        for arc in self.cur_fst.arcs(from_state):
            if arc.olabel == word:
                self.cur_node = arc.nextstate
                return self.weight_factor * w2f(arc.weight)
            elif arc.olabel == utils.UNK_ID:
                unk_arc = arc
        # No exact match: fall back to the UNK arc if available.
        if unk_arc is not None:
            self.cur_node = unk_arc.nextstate

    def get_state(self):
        """Returns the current node. """
        return self.cur_node

    def set_state(self, state):
        """Sets the current node. """
        self.cur_node = state

    def initialize_heuristic(self, src_sentence):
        """Creates a matrix of shortest distances between nodes. """
        self.distances = fst.shortestdistance(self.cur_fst,
                                              reverse=True)

    def estimate_future_cost(self, hypo):
        """The FST predictor comes with its own heuristic function. We
        use the shortest path in the fst as future cost estimator. """
        # Fixed: the original used ``if not self.cur_node`` which
        # wrongly treated the valid state ID 0 as inactive.
        if not self._has_active_node():
            return 0.0
        last_word = hypo.trgt_sentence[-1]
        for arc in self.cur_fst.arcs(self.cur_node):
            if arc.olabel == last_word:
                return w2f(self.distances[arc.nextstate])
        return 0.0

    def is_equal(self, state1, state2):
        """Returns true if the current node is the same """
        return state1 == state2
class NondeterministicFstPredictor(Predictor):
    """This predictor can handle non-deterministic translation
    lattices. In contrast to the fst predictor for deterministic
    lattices, we store a set of nodes which are all reachable from
    the start node through the current history.
    """

    def __init__(self, fst_path, use_weights, normalize_scores,
                 skip_bos_weight=True, to_log=True):
        """Creates a new nfst predictor.

        Args:
            fst_path (string): Path to the FST file
            use_weights (bool): If false, replace all arc weights with
                                0 (=log 1).
            normalize_scores (bool): If true, we normalize the weights
                                     on all outgoing arcs such that
                                     they sum up to 1
            skip_bos_weight (bool): If true, set weights on <S> arcs
                                    to 0 (= log1)
            to_log (bool): SGNMT uses normal log probs (scores) while
                           arc weights in FSTs normally have cost (i.e.
                           neg. log values) semantics. Therefore, if
                           true, we multiply arc weights by -1.
        """
        super(NondeterministicFstPredictor, self).__init__()
        self.fst_path = fst_path
        self.weight_factor = -1.0 if to_log else 1.0
        # "Best" score is max in log-prob semantics, min in cost
        # semantics.
        self.score_max_func = max if to_log else min
        self.use_weights = use_weights
        self.skip_bos_weight = skip_bos_weight
        self.normalize_scores = normalize_scores
        self.cur_fst = None
        # List of (accumulated_weight, node_id) tuples; empty list
        # means invalid state.
        self.cur_nodes = []

    def get_unk_probability(self, posterior):
        """Always returns negative infinity: Words outside the
        translation lattice are not possible according to this
        predictor.

        Returns:
            float. Negative infinity
        """
        return utils.NEG_INF

    def predict_next(self):
        """Uses the outgoing arcs from all current node to build up the
        scores for the next word. This method does not follow epsilon
        arcs: ``consume`` updates ``cur_nodes`` such that all reachable
        arcs with word ids are connected directly with a node in
        ``cur_nodes``. If there are multiple arcs with the same word,
        we use the log sum of the arc weights as score.

        Returns:
            dict. Set of words on outgoing arcs from the current node
            together with their scores, or an empty set if we currently
            have no active nodes or fst.
        """
        scores = {}
        for weight, node in self.cur_nodes:
            for arc in self.cur_fst.arcs(node):
                if arc.olabel != EPS_ID:
                    score = weight + self.weight_factor * w2f(arc.weight)
                    if arc.olabel in scores:
                        scores[arc.olabel] = self.score_max_func(
                            scores[arc.olabel], score)
                    else:
                        scores[arc.olabel] = score
        return self.finalize_posterior(
            scores, self.use_weights, self.normalize_scores)

    def initialize(self, src_sentence):
        """Loads the FST from the file system and consumes the start
        of sentence symbol.

        Args:
            src_sentence (list): Not used
        """
        self.cur_fst = load_fst(utils.get_path(self.fst_path,
                                               self.current_sen_id+1))
        self.cur_nodes = []
        if self.cur_fst:
            self.cur_nodes = self._follow_eps({self.cur_fst.start(): 0.0})
            self.consume(utils.GO_ID)
        if not self.cur_nodes:
            logging.warning(
                "The lattice for sentence %d does not contain any "
                "valid path. Please double-check that the lattice "
                "is not empty and that paths start with the begin-of-"
                "sentence symbol." % (self.current_sen_id+1))

    def consume(self, word):
        """Updates the current nodes by searching for all nodes which
        are reachable from the current nodes by a path consisting of
        any number of epsilons and exactly one ``word`` label. If there
        is no such arc, we set the predictor in an invalid state. In
        this case, all subsequent ``predict_next`` calls will return
        the empty set.

        Args:
            word (int): Word on an outgoing arc from the current node
        """
        d_unconsumed = {}  # Collect distances to nodes reachable by word
        for weight, node in self.cur_nodes:
            for arc in self.cur_fst.arcs(node):
                if arc.olabel == word:
                    next_node = arc.nextstate
                    next_score = weight \
                        + self.weight_factor * w2f(arc.weight)
                    if d_unconsumed.get(next_node,
                                        utils.NEG_INF) < next_score:
                        d_unconsumed[next_node] = next_score
        # Fixed: the original called score_max_func on an empty
        # sequence when ``word`` was not reachable, raising ValueError
        # instead of entering the documented invalid state.
        if not d_unconsumed:
            self.cur_nodes = []
            return
        # Subtract the word score from the last predict_next
        consumed_score = self.score_max_func(d_unconsumed.values()) \
            if (word != utils.GO_ID or self.skip_bos_weight) else 0.0
        # Add epsilon reachable states
        self.cur_nodes = self._follow_eps(
            {node: score - consumed_score
             for node, score in d_unconsumed.items()})

    def _follow_eps(self, roots):
        """BFS to find nodes reachable from root through eps arcs.

        This traversal strategy is efficient if the triangle
        inequality holds for weights in the graphs, i.e. for all
        vertices v1,v2,v3: (v1,v2),(v2,v3),(v1,v3) in E =>
        d(v1,v2)+d(v2,v3) >= d(v1,v3). The method still returns the
        correct results if the triangle inequality does not hold, but
        edges may be traversed multiple times which makes it more
        inefficient.

        Args:
            roots (dict): Maps node IDs to accumulated scores.

        Returns:
            list of (weight, node) tuples for all nodes with at least
            one non-epsilon outgoing arc.
        """
        open_nodes = dict(roots)
        d = {}
        visited = dict(roots)
        while open_nodes:
            next_open = {}
            for node, score in open_nodes.items():
                has_noneps = False
                for arc in self.cur_fst.arcs(node):
                    if arc.olabel == EPS_ID:
                        next_node = arc.nextstate
                        next_score = score \
                            + self.weight_factor * w2f(arc.weight)
                        if visited.get(next_node,
                                       utils.NEG_INF) < next_score:
                            visited[next_node] = next_score
                            next_open[next_node] = next_score
                    else:
                        has_noneps = True
                # Only keep nodes that can emit a real word next.
                if has_noneps:
                    d[node] = score
            open_nodes = next_open
        return [(weight, node) for node, weight in d.items()]

    def get_state(self):
        """Returns the set of current nodes """
        return self.cur_nodes

    def set_state(self, state):
        """Sets the set of current nodes """
        self.cur_nodes = state

    def initialize_heuristic(self, src_sentence):
        """Creates a matrix of shortest distances between all nodes """
        self.distances = fst.shortestdistance(self.cur_fst,
                                              reverse=True)

    def estimate_future_cost(self, hypo):
        """The FST predictor comes with its own heuristic function. We
        use the shortest path in the fst as future cost estimator. """
        last_word = hypo.trgt_sentence[-1]
        dists = []
        # Fixed: cur_nodes holds (weight, node) tuples, and pywrapfst
        # exposes arcs via fst.arcs(state); the original
        # ``self.cur_fst[n].arcs`` crashed at runtime.
        for _, node in self.cur_nodes:
            for arc in self.cur_fst.arcs(node):
                if arc.olabel == last_word:
                    dists.append(w2f(self.distances[arc.nextstate]))
                    break
        return 0.0 if not dists else min(dists)

    def is_equal(self, state1, state2):
        """Returns true if the current nodes are the same """
        return sorted([n for _, n in state1]) == \
            sorted([n for _, n in state2])
class RtnPredictor(Predictor):
    """Predictor for RTNs (recurrent transition networks). This
    predictor assumes a directory structure as produced by HiFST. You
    can use this predictor for non-deterministic lattices too. This
    implementation supports late expansion: RTNs are only expanded as
    far as necessary to retrieve all currently reachable states.

    ``cur_nodes`` contains the accumulated weights from the last
    consumed word (if ambiguous, the largest)

    This implementation does not maintain a list of active nodes like
    the other automata predictors. Instead, we store the current
    history and search for the active nodes at each expansion. This is
    more expensive, but fstreplace might change state IDs so a list of
    active nodes might get corrupted.

    Note that this predictor does not support FSTs in gzip format.
    """

    def __init__(self, rtn_path, use_weights, normalize_scores,
                 to_log=True, minimize_rtns=False, rmeps=True):
        """Creates a new RTN predictor.

        Args:
            rtn_path (string): Path to the RTN directory
            use_weights (bool): If false, replace all arc weights with
                                0 (=log 1).
            normalize_scores (bool): If true, we normalize the weights
                                     on all outgoing arcs such that
                                     they sum up to 1
            to_log (bool): SGNMT uses normal log probs (scores) while
                           arc weights in FSTs normally have cost (i.e.
                           neg. log values) semantics. Therefore, if
                           true, we multiply arc weights by -1.
            minimize_rtns (bool): Minimize the FST after each replace
                                  operation
            rmeps (bool): Remove epsilons in the FST after each replace
                          operation
        """
        super(RtnPredictor, self).__init__()
        self.root_path = rtn_path
        self.minimize_rtns = minimize_rtns
        self.rmeps = rmeps
        self.use_weights = use_weights
        self.normalize_scores = normalize_scores
        self.weight_factor = -1.0 if to_log else 1.0
        self.cur_fst = None  # current root fst
        start_id = '1'
        try:
            with open("%s/ntmap" % self.root_path) as f:
                ntmap = dict(line.strip().split(None, 1) for line in f)
            start_id = ntmap['S']
        # Fixed: narrowed the original bare ``except:`` which also
        # swallowed KeyboardInterrupt/SystemExit.
        except Exception:
            logging.warning(
                "Could not find NT S in ntmap. Assuming its ID 1")
        # NT labels are 10 digits: leading '1', 3-digit NT id, then a
        # span encoding; the root FST covers the full span.
        self.root_fst_prefix = "1%s000" % start_id.zfill(3)

    def get_unk_probability(self, posterior):
        """Always returns negative infinity: Words outside the RTN are
        not possible according to this predictor.

        Returns:
            float. Negative infinity
        """
        return utils.NEG_INF

    def initialize(self, src_sentence):
        """Loads the root RTN and consumes the start of sentence
        symbol.

        Args:
            src_sentence (list): Not used
        """
        try:
            file_name = "%s/%d.fst" % (self.root_path,
                                       self.current_sen_id+1)
            if not os.access(file_name, os.R_OK):  # Find root FST
                search_pattern = '%s/%d/%s*.fst' % (
                    self.root_path,
                    self.current_sen_id+1,
                    self.root_fst_prefix)
                candidates = glob.glob(search_pattern)
                if not candidates:
                    logging.error("Could not find root fst in %s"
                                  % search_pattern)
                    self.cur_fst = None
                    return
                if len(candidates) > 1:
                    logging.warning(
                        "Ambiguous root fst for %s. Take the one "
                        "with largest span." % search_pattern)
                    candidates = sorted(candidates)
                file_name = candidates[-1]
            self.cur_fst = fst.Fst.read(file_name)
            logging.debug("Read (root)fst from %s" % file_name)
        except Exception as e:
            logging.error("%s error reading fst from %s: %s" %
                          (sys.exc_info()[1], file_name, e))
            self.cur_fst = None
        finally:
            # Always reset the history and sub-FST cache, even if
            # loading failed, and consume <S>.
            self.cur_history = []
            self.sub_fsts = {}
            self.consume(utils.GO_ID)

    def expand_rtn(self, func):
        """This method expands the RTN as far as necessary. This means
        that the RTN is expanded s.t. we can build the posterior for
        ``cur_history``. In practice, this means that we follow all
        epsilon edges and replaces all NT edges until all paths with
        the prefix ``cur_history`` in the RTN have at least one more
        terminal token. Then, we apply ``func`` to all reachable nodes.
        """
        updated = True
        while updated:
            updated = False
            label_fst_map = {}
            self.visited_nodes = {}
            # Arc-sort by output label so the recursion can stop early
            # once arc.olabel exceeds the history token.
            self.cur_fst.arcsort(sort_type="olabel")
            self.add_to_label_fst_map_recursive(
                label_fst_map, {}, self.cur_fst.start(), 0.0,
                self.cur_history, func)
            if label_fst_map:
                logging.debug("Replace %d NT arcs for history %s" % (
                    len(label_fst_map), self.cur_history))
                # First in the list is the root FST and label
                replaced_fst = fst.replace(
                    [(len(label_fst_map) + 2000000000, self.cur_fst)]
                    + [(nt_label, f)
                       for (nt_label, f) in label_fst_map.items()],
                    epsilon_on_replace=True)
                self.cur_fst = replaced_fst
                updated = True
                if self.rmeps or self.minimize_rtns:
                    self.cur_fst.rmepsilon()
                if self.minimize_rtns:
                    # Fixed: the original passed the bound method
                    # ``self.cur_fst.determinize`` to fst.determinize
                    # instead of the FST itself.
                    tmp = fst.determinize(self.cur_fst)
                    self.cur_fst = tmp
                    self.cur_fst.minimize()

    def add_to_label_fst_map_recursive(self, label_fst_map,
                                       visited_nodes, root_node,
                                       acc_weight, history, func):
        """Adds arcs to ``label_fst_map`` if they are labeled with an
        NT symbol and reachable from ``root_node`` via ``history``.

        Note: visited_nodes is maintained for each history separately
        """
        if root_node in visited_nodes:
            # This introduces some error as we take the score of the
            # first best path with a certain history, not the globally
            # best path. For now, this error should not be significant
            return
        visited_nodes[root_node] = True
        for arc in self.cur_fst.arcs(root_node):
            arc_acc_weight = acc_weight \
                + self.weight_factor * w2f(arc.weight)
            if arc.olabel == EPS_ID:  # Follow epsilon edges
                self.add_to_label_fst_map_recursive(
                    label_fst_map, visited_nodes, arc.nextstate,
                    arc_acc_weight, history, func)
            elif not history:
                if self.is_nt_label(arc.olabel):
                    # Add to label_fst_map
                    replace_label = len(label_fst_map) + 2000000000
                    label_fst_map[replace_label] = self.get_sub_fst(
                        arc.olabel)
                    # NOTE(review): pywrapfst arc iterators may yield
                    # copies, in which case these assignments would not
                    # modify the underlying FST — confirm against the
                    # pywrapfst version in use (MutableArcIterator with
                    # set_value() is the documented mutation path).
                    arc.ilabel = replace_label
                    arc.olabel = replace_label
                else:
                    # This is a regular arc and we have no history left
                    func(arc.nextstate, arc.olabel,
                         arc_acc_weight)  # apply func
            elif arc.olabel == history[0]:  # history is not empty
                self.add_to_label_fst_map_recursive(
                    label_fst_map, {}, arc.nextstate, arc_acc_weight,
                    history[1:], func)
            elif arc.olabel > history[0]:
                # FST is arc sorted, we can stop here
                break

    def is_nt_label(self, label):
        """Returns true if ``label`` is a non-terminal. """
        s = str(label)
        return len(s) == 10 and s[0] == '1'

    def get_sub_fst(self, fst_id):
        """ Load sub fst from the file system or the cache """
        if fst_id in self.sub_fsts:
            return self.sub_fsts[fst_id]
        sub_fst_path = "%s/%d/%d.fst" % (self.root_path,
                                         self.current_sen_id+1,
                                         fst_id)
        try:
            sub_fst = fst.Fst.read(sub_fst_path)
            logging.debug("Read sub fst from %s" % sub_fst_path)
            self.sub_fsts[fst_id] = sub_fst
            return sub_fst
        except Exception as e:
            logging.error("%s error reading sub fst from %s: %s" %
                          (sys.exc_info()[1], sub_fst_path, e))

    def _add_to_cur_posterior(self, node, label, weight):
        """Can be used as ``func`` argument in ``expand_rtn`` to build
        up the posterior for the next target token in ``predict_next``
        """
        self.cur_posterior[label] = max(
            self.cur_posterior.get(label, utils.NEG_INF), weight)

    def predict_next(self):
        """Expands RTN as far as possible and uses the outgoing edges
        from nodes reachable by the current history to build up the
        posterior for the next word. If there are no such nodes or
        arcs, or no root FST is loaded, return the empty set.
        """
        if not self.cur_fst:
            return {}
        self.cur_posterior = {}
        self.expand_rtn(self._add_to_cur_posterior)
        return self.finalize_posterior(
            self.cur_posterior, self.use_weights, self.normalize_scores)

    def consume(self, word):
        """Adds ``word`` to the current history. """
        self.cur_history.append(word)

    def get_state(self):
        """Returns the current history. """
        return self.cur_history

    def set_state(self, state):
        """Sets the current history. """
        self.cur_history = state