Source code for cam.sgnmt.tf_utils

# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains utility functions for TensorFlow such as
session handling and checkpoint loading.
"""

try:
    # This is the TF backend needed for MoE interpolation
    import tensorflow as tf
    from tensorflow.python.training import saver
    from tensorflow.python.training import training
except ImportError:
    pass # Deal with it in decode.py

import os
import logging


def session_config(n_cpu_threads=-1):
    """Creates the session config with default parameters.

    Args:
      n_cpu_threads (int): Number of CPU threads. If negative, we
                           assume either GPU decoding or that all
                           CPU cores can be used.

    Returns:
      A TF session config object.
    """
    graph_options = tf.GraphOptions(optimizer_options=tf.OptimizerOptions(
        opt_level=tf.OptimizerOptions.L1, do_function_inlining=False))
    if n_cpu_threads < 0:
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=0.95)
        config = tf.ConfigProto(
            allow_soft_placement=True,
            graph_options=graph_options,
            gpu_options=gpu_options,
            log_device_placement=False)
    else:
        #device_count={'CPU': n_cpu_threads},
        if n_cpu_threads >= 4:
            # This adjustment is an estimate of the effective load which
            # accounts for the sequential parts in SGNMT.
            if n_cpu_threads == 4:
                n_cpu_threads = 5
            else:
                n_cpu_threads = int(n_cpu_threads*5/1.5 - 10)
            logging.debug("Setting TF inter and intra op parallelism "
                          "to %d" % n_cpu_threads)
        config = tf.ConfigProto(
            intra_op_parallelism_threads=n_cpu_threads,
            inter_op_parallelism_threads=n_cpu_threads,
            allow_soft_placement=True,
            graph_options=graph_options,
            log_device_placement=False)
    return config
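
The object returned above is a plain tf.ConfigProto, so it can also be passed to a regular TF 1.x session outside of create_session. A minimal sketch, not part of the original module; the thread count and the one-constant graph are arbitrary illustrative choices:

    import tensorflow as tf
    from cam.sgnmt.tf_utils import session_config

    # Build the SGNMT session config for 8 CPU threads (illustrative value)
    # and run a trivial graph with it.
    config = session_config(n_cpu_threads=8)
    with tf.Session(config=config) as sess:
        print(sess.run(tf.constant(42)))
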
def create_session(checkpoint_path, n_cpu_threads=-1):
    """Creates a MonitoredSession.

    Args:
      checkpoint_path (string): Path either to checkpoint directory or
                                directly to a checkpoint file.
      n_cpu_threads (int): Number of CPU threads. If negative, we
                           assume either GPU decoding or that all
                           CPU cores can be used.

    Returns:
      A TensorFlow MonitoredSession.
    """
    try:
        if os.path.isdir(checkpoint_path):
            checkpoint_path = saver.latest_checkpoint(checkpoint_path)
        else:
            logging.info("%s is not a directory. Interpreting as direct "
                         "path to checkpoint..." % checkpoint_path)
        return training.MonitoredSession(
            session_creator=training.ChiefSessionCreator(
                checkpoint_filename_with_path=checkpoint_path,
                config=session_config(n_cpu_threads)))
    except tf.errors.NotFoundError as e:
        logging.fatal("Could not find all variables of the computation "
                      "graph in the T2T checkpoint file. This means that "
                      "the checkpoint does not correspond to the model "
                      "specified in SGNMT. Please double-check "
                      "pred_src_vocab_size, pred_trg_vocab_size, and all "
                      "the t2t_* parameters. Also make sure that the "
                      "checkpoint exists and is readable.")
        raise AttributeError("Could not initialize TF session.")
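
A minimal usage sketch for create_session, again not part of the original module. The checkpoint directory below is a placeholder path, and fetches stands for whatever tensors the caller's restored graph defines:

    from cam.sgnmt.tf_utils import create_session

    # Restores the latest checkpoint found in the directory and returns a
    # MonitoredSession; pass a checkpoint file path instead to load a
    # specific checkpoint.
    session = create_session("/path/to/checkpoint_dir", n_cpu_threads=-1)
    # outputs = session.run(fetches)  # fetches: tensors from the restored graph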