# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module encapsulates the predictor interface to OpenFST. This
module depends on OpenFST. To enable Python support in OpenFST, use a 
recent version (>=1.5.4) and compile with ``--enable_python``. 
Further information can be found here:
http://www.openfst.org/twiki/bin/view/FST/PythonExtension 
This file includes the fst, nfst, and rtn predictors.
Note: If arc weights are used, they are multiplied by -1 because
scores in SGNMT are log-probabilities, whereas FST weights in the log
and tropical semirings are negative log-probabilities (costs). You can
disable this behavior with --fst_to_log.
Note2: The FSTs and RTNs are assumed to have both <S> and </S>. This 
has compatibility reasons, as lattices generated by HiFST have these
symbols.
"""
import glob
import logging
import os
import sys
from cam.sgnmt import utils
from cam.sgnmt.predictors.core import Predictor
from cam.sgnmt.utils import w2f, load_fst
try:
    import pywrapfst as fst
except ImportError:
    try:
        import openfst_python as fst
    except ImportError:
        pass # Deal with it in decode.py
#: Label id reserved by OpenFST for epsilon arcs; arcs carrying this
#: output label are traversed without consuming a target word.
EPS_ID = 0
class FstPredictor(Predictor):
    """This predictor can read determinized translation lattices. The
    predictor state consists of the current node. This is unique as the
    lattices are determinized.

    A negative node id (or ``None``) marks an invalid state, e.g. after
    consuming a word without a matching arc. In that state
    ``predict_next`` returns an empty posterior.
    """

    def __init__(self,
                 fst_path,
                 use_weights,
                 normalize_scores,
                 skip_bos_weight=True,
                 to_log=True):
        """Creates a new fst predictor.

        Args:
            fst_path (string): Path to the FST file
            use_weights (bool): If false, replace all arc weights with
                                0 (=log 1).
            normalize_scores (bool): If true, we normalize the weights
                                     on all outgoing arcs such that
                                     they sum up to 1
            skip_bos_weight (bool): Add the score at the <S> arc to the
                                    </S> arc if this is false. This results
                                    in scores consistent with
                                    OpenFST's replace operation,
                                    as <S> scores are normally
                                    ignored by SGNMT.
            to_log (bool): SGNMT uses normal log probs (scores) while
                           arc weights in FSTs normally have cost (i.e.
                           neg. log values) semantics. Therefore, if
                           true, we multiply arc weights by -1.
        """
        super(FstPredictor, self).__init__()
        self.fst_path = fst_path
        self.weight_factor = -1.0 if to_log else 1.0
        self.use_weights = use_weights
        self.normalize_scores = normalize_scores
        self.cur_fst = None
        self.add_bos_to_eos_score = not skip_bos_weight
        self.cur_node = -1
        # Score of the <S> arc; (re)computed in initialize()
        self.bos_score = 0.0

    def _is_active(self):
        """Returns true if ``cur_node`` refers to an actual FST state.

        Both ``None`` and negative ids denote the invalid state. The
        explicit ``None`` check matters on Python 3, where ``None < 0``
        raises a ``TypeError``.
        """
        return self.cur_node is not None and self.cur_node >= 0

    def get_unk_probability(self, posterior):
        """Returns negative infinity if UNK is not in the lattice.
        Otherwise, return UNK score.

        Returns:
            float. UNK score from ``posterior`` or negative infinity
        """
        return utils.common_get(posterior, utils.UNK_ID, utils.NEG_INF)

    def predict_next(self):
        """Uses the outgoing arcs from the current node to build up the
        scores for the next word.

        Returns:
            dict. Words on outgoing arcs from the current node mapped to
            their scores, or an empty dict if we currently have no
            active node or fst.
        """
        if not self._is_active():
            return {}
        scores = {arc.olabel: self.weight_factor * w2f(arc.weight)
                  for arc in self.cur_fst.arcs(self.cur_node)}
        if utils.EOS_ID in scores and self.add_bos_to_eos_score:
            # Mimic OpenFST's replace: credit the <S> weight at </S>
            scores[utils.EOS_ID] += self.bos_score
        return self.finalize_posterior(scores,
                                       self.use_weights,
                                       self.normalize_scores)

    def initialize(self, src_sentence):
        """Loads the FST from the file system and consumes the start
        of sentence symbol. Logs a warning if the predictor ends up in
        the invalid state (no FST, empty FST, or no <S> arc).

        Args:
            src_sentence (list):  Not used
        """
        self.cur_fst = load_fst(utils.get_path(self.fst_path,
                                               self.current_sen_id + 1))
        self.cur_node = self.cur_fst.start() if self.cur_fst else -1
        self.bos_score = self.consume(utils.GO_ID)
        if not self.bos_score:  # consume() returns None if no arc weight
            self.bos_score = 0.0
        if not self._is_active():
            logging.warning("The lattice for sentence %d does not contain any "
                            "valid path. Please double-check that the lattice "
                            "is not empty and that paths contain the begin-of-"
                            "sentence symbol %d. If you are using a different "
                            "begin-of-sentence symbol, double-check --indexing_"
                            "scheme." % (self.current_sen_id+1, utils.GO_ID))

    def consume(self, word):
        """Updates the current node by following the arc labelled with
        ``word``. If there is no such arc but an arc labelled with UNK,
        the UNK arc is followed. Otherwise we set ``cur_node`` to -1,
        indicating that the predictor is in an invalid state. In this
        case, all subsequent ``predict_next`` calls will return the
        empty set.

        Args:
            word (int): Word on an outgoing arc from the current node

        Returns:
            float. Weight on the traversed arc (times ``weight_factor``),
            or None if no arc labelled with ``word`` was traversed.
        """
        if not self._is_active():
            return None
        from_state = self.cur_node
        self.cur_node = -1
        unk_arc = None
        for arc in self.cur_fst.arcs(from_state):
            if arc.olabel == word:
                self.cur_node = arc.nextstate
                return self.weight_factor * w2f(arc.weight)
            if arc.olabel == utils.UNK_ID:
                unk_arc = arc
        if unk_arc is not None:
            self.cur_node = unk_arc.nextstate
        return None

    def get_state(self):
        """Returns the current node. """
        return self.cur_node

    def set_state(self, state):
        """Sets the current node. """
        self.cur_node = state

    def initialize_heuristic(self, src_sentence):
        """Computes the reverse shortest distances (i.e. distances to
        the final states) for all nodes of the current FST. """
        self.distances = fst.shortestdistance(self.cur_fst, reverse=True)

    def estimate_future_cost(self, hypo):
        """The FST predictor comes with its own heuristic function. We
        use the shortest distance from the node reached by the last
        word of ``hypo`` to a final state as future cost estimate.
        Returns 0.0 if the predictor is inactive or no such arc exists.
        """
        if not self._is_active():
            return 0.0
        last_word = hypo.trgt_sentence[-1]
        for arc in self.cur_fst.arcs(self.cur_node):
            if arc.olabel == last_word:
                return w2f(self.distances[arc.nextstate])
        return 0.0

    def is_equal(self, state1, state2):
        """Returns true if the current node is the same """
        return state1 == state2
class NondeterministicFstPredictor(Predictor):
    """This predictor can handle non-deterministic translation
    lattices. In contrast to the fst predictor for deterministic
    lattices, we store a set of nodes which are all reachable from
    the start node through the current history.

    The predictor state ``cur_nodes`` is a list of ``(score, node)``
    tuples. Every node in it has at least one outgoing non-epsilon arc.
    An empty list is the invalid state.
    """

    def __init__(self,
                 fst_path,
                 use_weights,
                 normalize_scores,
                 skip_bos_weight=True,
                 to_log=True):
        """Creates a new nfst predictor.

        Args:
            fst_path (string): Path to the FST file
            use_weights (bool): If false, replace all arc weights with
                                0 (=log 1).
            normalize_scores (bool): If true, we normalize the weights
                                     on all outgoing arcs such that
                                     they sum up to 1
            skip_bos_weight (bool): If true, set weights on <S> arcs
                                    to 0 (= log1)
            to_log (bool): SGNMT uses normal log probs (scores) while
                           arc weights in FSTs normally have cost (i.e.
                           neg. log values) semantics. Therefore, if
                           true, we multiply arc weights by -1.
        """
        super(NondeterministicFstPredictor, self).__init__()
        self.fst_path = fst_path
        self.weight_factor = -1.0 if to_log else 1.0
        self.score_max_func = max if to_log else min
        self.use_weights = use_weights
        self.skip_bos_weight = skip_bos_weight
        self.normalize_scores = normalize_scores
        self.cur_fst = None
        self.cur_nodes = []

    def get_unk_probability(self, posterior):
        """Always returns negative infinity: Words outside the
        translation lattice are not possible according to this
        predictor.

        Returns:
            float. Negative infinity
        """
        return utils.NEG_INF

    def predict_next(self):
        """Uses the outgoing arcs from all current nodes to build up the
        scores for the next word. This method does not follow epsilon
        arcs: ``consume`` updates ``cur_nodes`` such that all reachable
        arcs with word ids are connected directly with a node in
        ``cur_nodes``. If there are multiple arcs with the same word,
        the best score according to ``score_max_func`` is used.

        Returns:
            dict. Words on outgoing arcs from the current nodes mapped
            to their scores, or an empty dict if we currently have no
            active nodes or fst.
        """
        scores = {}
        for weight, node in self.cur_nodes:
            for arc in self.cur_fst.arcs(node):
                if arc.olabel == EPS_ID:
                    continue
                score = weight + self.weight_factor * w2f(arc.weight)
                if arc.olabel in scores:
                    scores[arc.olabel] = self.score_max_func(
                        scores[arc.olabel], score)
                else:
                    scores[arc.olabel] = score
        return self.finalize_posterior(scores,
                                       self.use_weights,
                                       self.normalize_scores)

    def initialize(self, src_sentence):
        """Loads the FST from the file system and consumes the start
        of sentence symbol. Logs a warning if no valid path remains.

        Args:
            src_sentence (list):  Not used
        """
        self.cur_fst = load_fst(utils.get_path(self.fst_path,
                                               self.current_sen_id + 1))
        self.cur_nodes = []
        if self.cur_fst:
            start = self.cur_fst.start()
            # OpenFST reports a negative id (kNoStateId) for FSTs without
            # a start state, e.g. empty FSTs
            if start >= 0:
                self.cur_nodes = self._follow_eps({start: 0.0})
        self.consume(utils.GO_ID)
        if not self.cur_nodes:
            logging.warning("The lattice for sentence %d does not contain any "
                            "valid path. Please double-check that the lattice "
                            "is not empty and that paths start with the begin-of-"
                            "sentence symbol." % (self.current_sen_id+1))

    def consume(self, word):
        """Updates the current nodes by searching for all nodes which
        are reachable from the current nodes by a path consisting of
        any number of epsilons and exactly one ``word`` label. If there
        is no such arc, we set the predictor in an invalid state (empty
        ``cur_nodes``). In this case, all subsequent ``predict_next``
        calls will return the empty set.

        Args:
            word (int): Word on an outgoing arc from the current node
        """
        d_unconsumed = {}
        # Collect the best score for each node reachable by ``word``
        for weight, node in self.cur_nodes:
            for arc in self.cur_fst.arcs(node):
                if arc.olabel == word:
                    next_node = arc.nextstate
                    next_score = weight + self.weight_factor * w2f(arc.weight)
                    if d_unconsumed.get(next_node, utils.NEG_INF) < next_score:
                        d_unconsumed[next_node] = next_score
        if not d_unconsumed:
            # ``word`` is not reachable: enter the invalid state instead
            # of calling max()/min() on an empty sequence
            self.cur_nodes = []
            return
        # Subtract the word score from the last predict_next. The <S>
        # score is kept in the accumulated scores if skip_bos_weight is
        # false.
        if word != utils.GO_ID or self.skip_bos_weight:
            consumed_score = self.score_max_func(d_unconsumed.values())
        else:
            consumed_score = 0.0
        # Add epsilon reachable states
        self.cur_nodes = self._follow_eps(
            {node: score - consumed_score
             for node, score in d_unconsumed.items()})

    def _follow_eps(self, roots):
        """Finds nodes reachable from ``roots`` through eps arcs.

        Traverses epsilon arcs level by level, keeping the highest score
        per node and re-expanding a node whenever its score improves.
        This traversal strategy is efficient if the triangle inequality
        holds for weights in the graphs, i.e. for all vertices v1,v2,v3:
        (v1,v2),(v2,v3),(v1,v3) in E => d(v1,v2)+d(v2,v3) >= d(v1,v3).
        The method still returns the correct results if the triangle
        inequality does not hold, but edges may be traversed multiple
        times which makes it more inefficient.

        Args:
            roots (dict): Start nodes mapped to their scores

        Returns:
            list. ``(score, node)`` tuples for all reached nodes with at
            least one outgoing non-epsilon arc.
        """
        open_nodes = dict(roots)
        d = {}
        visited = dict(roots)
        while open_nodes:
            next_open = {}
            for node, score in open_nodes.items():
                has_noneps = False
                for arc in self.cur_fst.arcs(node):
                    if arc.olabel == EPS_ID:
                        next_node = arc.nextstate
                        next_score = score + self.weight_factor * w2f(arc.weight)
                        if visited.get(next_node, utils.NEG_INF) < next_score:
                            visited[next_node] = next_score
                            next_open[next_node] = next_score
                    else:
                        has_noneps = True
                if has_noneps:
                    d[node] = score
            open_nodes = next_open
        return [(weight, node) for node, weight in d.items()]

    def get_state(self):
        """Returns the set of current nodes """
        return self.cur_nodes

    def set_state(self, state):
        """Sets the set of current nodes """
        self.cur_nodes = state

    def initialize_heuristic(self, src_sentence):
        """Computes the reverse shortest distances (i.e. distances to
        the final states) for all nodes of the current FST. """
        self.distances = fst.shortestdistance(self.cur_fst, reverse=True)

    def estimate_future_cost(self, hypo):
        """The FST predictor comes with its own heuristic function. We
        use the shortest distance to a final state from the nodes reached
        by the last word of ``hypo`` as future cost estimator. Returns
        0.0 if no current node has an arc with that word. """
        last_word = hypo.trgt_sentence[-1]
        dists = []
        # cur_nodes holds (score, node) tuples, not bare node ids
        for _, node in self.cur_nodes:
            for arc in self.cur_fst.arcs(node):
                if arc.olabel == last_word:
                    dists.append(w2f(self.distances[arc.nextstate]))
                    break
        return min(dists) if dists else 0.0

    def is_equal(self, state1, state2):
        """Returns true if the current nodes are the same """
        return sorted([n for _, n in state1]) == sorted([n for _, n in state2])
[docs]class RtnPredictor(Predictor):
    """Predictor for RTNs (recurrent transition networks). This 
    predictor assumes a directory structure as produced by HiFST. You 
    can use this predictor for non-deterministic lattices too. This
    implementation supports late expansion: RTNs are only expanded as
    far as necessary to retrieve all currently reachable states.
    
    ``cur_nodes`` contains the accumulated weights from the last 
    consumed word (if ambiguous, the largest)
    
    This implementation does not maintain a list of active nodes like 
    the other automata predictors. Instead, we store the current 
    history and search for the active nodes at each expansion. This is
    more expensive, but fstreplace might change state IDs so a list of
    active nodes might get corrupted.
    
    Note that this predictor does not support FSTs in gzip format.
    """
    
    def __init__(self,
                 rtn_path,
                 use_weights,
                 normalize_scores,
                 to_log = True,
                 minimize_rtns = False,
                 rmeps = True):
        """Creates a new RTN predictor.

        Reads ``<rtn_path>/ntmap`` to find the id of the start
        non-terminal ``S``; falls back to id 1 if the file is missing,
        malformed, or has no ``S`` entry.

        Args:
            rtn_path (string): Path to the RTN directory
            use_weights (bool): If false, replace all arc weights with
                                0 (=log 1).
            normalize_scores (bool): If true, we normalize the weights
                                     on all outgoing arcs such that
                                     they sum up to 1
            to_log (bool): SGNMT uses normal log probs (scores) while
                           arc weights in FSTs normally have cost (i.e.
                           neg. log values) semantics. Therefore, if
                           true, we multiply arc weights by -1.
            minimize_rtns (bool): Minimize the FST after each replace
                                  operation
            rmeps (bool): Remove epsilons in the FST after each replace
                          operation
        """
        super(RtnPredictor, self).__init__()
        self.root_path = rtn_path
        self.minimize_rtns = minimize_rtns
        self.rmeps = rmeps
        self.use_weights = use_weights
        self.normalize_scores = normalize_scores
        self.weight_factor = -1.0 if to_log else 1.0
        self.cur_fst = None # current root fst
        # Reset per sentence in initialize(); set here so the predictor
        # is usable (empty) before the first initialize() call
        self.cur_history = []
        self.sub_fsts = {}
        start_id = '1'
        try:
            with open("%s/ntmap" % self.root_path) as f:
                # Each line: "<NT name> <NT id>"
                ntmap = dict(line.strip().split(None, 1) for line in f)
            start_id = ntmap['S']
        except (IOError, OSError, KeyError, ValueError):
            # Missing file, missing 'S' entry, or malformed lines
            # (dict() raises ValueError for lines without two fields)
            logging.warning("Could not find NT S in ntmap. Assuming its ID 1")
        self.root_fst_prefix = "1%s000" % start_id.zfill(3)
        
[docs]    def get_unk_probability(self, posterior):
        """Always returns negative infinity: Words outside the 
        RTN are not possible according to this predictor.
        
        Returns:
            float. Negative infinity
        """
        return utils.NEG_INF 
    
[docs]    def initialize(self, src_sentence):
        """Loads the root RTN and consumes the start of sentence 
        symbol.
        
        Args:
            src_sentence (list):  Not used
        """
        try:
            file_name = "%s/%d.fst" % (self.root_path, self.current_sen_id+1)
            if not os.access(file_name, os.R_OK): # Find root FST
                search_pattern = '%s/%d/%s*.fst' % (self.root_path,
                                                    self.current_sen_id+1,
                                                    self.root_fst_prefix)
                candidates = glob.glob(search_pattern)
                if not candidates:
                    logging.error("Could not find root fst in %s" % 
                                    search_pattern)
                    self.cur_fst = None
                    return
                if len(candidates) > 1:
                    logging.warn("Ambiguous root fst for %s. Take the one "
                                 "with largest span." % search_pattern)
                    candidates = sorted(candidates)
                file_name = candidates[-1]
            self.cur_fst = fst.Fst.read(file_name) 
            logging.debug("Read (root)fst from %s" % file_name)
        except Exception as e:
            logging.error("%s error reading fst from %s: %s" %
                (sys.exc_info()[1], file_name, e))
            self.cur_fst = None
        finally:
            self.cur_history = []
            self.sub_fsts = {}
        self.consume(utils.GO_ID) 
    
[docs]    def expand_rtn(self, func):
        """This method expands the RTN as far as necessary. This means
        that the RTN is expanded s.t. we can build the posterior for 
        ``cur_history``. In practice, this means that we follow all 
        epsilon edges and replaces all NT edges until all paths with 
        the prefix ``cur_history`` in the RTN have at least one more 
        terminal token. Then, we apply ``func`` to all reachable nodes.
        """
        updated = True
        while updated:
            updated = False
            label_fst_map = {}
            self.visited_nodes = {}
            self.cur_fst.arcsort(sort_type="olabel")
            self.add_to_label_fst_map_recursive(label_fst_map,
                                                {},
                                                self.cur_fst.start(), 
                                                0.0,
                                                self.cur_history, func)
            if label_fst_map:
                logging.debug("Replace %d NT arcs for history %s" % (
                                                            len(label_fst_map),
                                                            self.cur_history))
                # First in the list is the root FST and label
                replaced_fst = fst.replace(
                        [(len(label_fst_map) + 2000000000, self.cur_fst)] 
                        + [(nt_label, f) 
                            for (nt_label, f) in label_fst_map.items()],
                        epsilon_on_replace=True)
                self.cur_fst = replaced_fst
                updated = True
        if self.rmeps or self.minimize_rtns:
            self.cur_fst.rmepsilon()
        if self.minimize_rtns:
            tmp = fst.determinize(self.cur_fst.determinize)
            self.cur_fst = tmp
            self.cur_fst.minimize() 
    
[docs]    def add_to_label_fst_map_recursive(self, 
                                       label_fst_map, 
                                       visited_nodes, 
                                       root_node, 
                                       acc_weight, 
                                       history, 
                                       func):
        """Adds arcs to ``label_fst_map`` if they are labeled with an
        NT symbol and reachable from ``root_node`` via ``history``.
          
        Note: visited_nodes is maintained for each history separately
        """
        if root_node in visited_nodes:
            # This introduces some error as we take the score of the first best
            # path with a certain history, not the globally best path. For now,
            # this error should not be significant
            return
        visited_nodes[root_node] = True
        for arc in self.cur_fst.arcs(root_node):
            arc_acc_weight = acc_weight + self.weight_factor*w2f(arc.weight)
            if arc.olabel == EPS_ID: # Follow epsilon edges
                self.add_to_label_fst_map_recursive(label_fst_map,
                                                    visited_nodes,
                                                    arc.nextstate,
                                                    arc_acc_weight, 
                                                    history,
                                                    func)
            elif not history:
                if self.is_nt_label(arc.olabel): # Add to label_fst_map
                    replace_label = len(label_fst_map) + 2000000000
                    label_fst_map[replace_label] = self.get_sub_fst(
                                                                    arc.olabel)
                    arc.ilabel = replace_label
                    arc.olabel = replace_label
                else: # This is a regular arc and we have no history left
                    func(arc.nextstate, arc.olabel, arc_acc_weight) # apply func
            elif arc.olabel == history[0]: # history is not empty
                self.add_to_label_fst_map_recursive(label_fst_map,
                                                    {},
                                                    arc.nextstate,
                                                    arc_acc_weight,
                                                    history[1:],
                                                    func)
            elif arc.olabel > history[0]: # FST is arc sorted, we can stop here
                break 
        
    
[docs]    def is_nt_label(self, label):
        """Returns true if ``label`` is a non-terminal. """
        s = str(label)
        return len(s) == 10 and s[0] == '1' 
[docs]    def get_sub_fst(self, fst_id):
        """ Load sub fst from the file system or the cache """
        if fst_id in self.sub_fsts:
            return self.sub_fsts[fst_id]
        sub_fst_path = "%s/%d/%d.fst" %  (self.root_path,
                                          self.current_sen_id+1,
                                          fst_id)
        try:
            sub_fst = fst.Fst.read(sub_fst_path)
            logging.debug("Read sub fst from %s" % sub_fst_path)
            self.sub_fsts[fst_id] = sub_fst
            return sub_fst
        except Exception as e:
            logging.error("%s error reading sub fst from %s: %s" %
                (sys.exc_info()[1], sub_fst_path, e)) 
        
    def _add_to_cur_posterior(self, node, label, weight):
        """Callback for ``expand_rtn`` that ``predict_next`` uses to
        fill ``cur_posterior``.

        Stores ``weight`` for ``label``, or keeps the previously stored
        score if it is larger. ``node`` is ignored.
        """
        best_so_far = self.cur_posterior.get(label, utils.NEG_INF)
        self.cur_posterior[label] = max(best_so_far, weight)
    
[docs]    def predict_next(self):
        """Expands RTN as far as possible and uses the outgoing edges 
        from nodes reachable by the current history to build up
        the posterior for the next word. If there are no such nodes
        or arcs, or no root FST is loaded, return the empty set.
        """
        if not self.cur_fst:
            return {}
        self.cur_posterior = {}
        self.expand_rtn(self._add_to_cur_posterior)
        return self.finalize_posterior(self.cur_posterior,
                                       self.use_weights,
                                       self.normalize_scores) 
    
[docs]    def consume(self, word):
        """Adds ``word`` to the current history. """
        self.cur_history.append(word) 
    
[docs]    def get_state(self):
        """Returns the current history. """
        return self.cur_history 
    
[docs]    def set_state(self, state):
        """Sets the current history. """
        self.cur_history = state