# Source code for tensorpack.callbacks.concurrency

# -*- coding: utf-8 -*-
# File: concurrency.py

import multiprocessing as mp

from ..utils import logger
from ..utils.concurrency import StoppableThread, start_proc_mask_signal
from .base import Callback

# Public names exported by ``from <this module> import *``.
__all__ = [
    'StartProcOrThread',
]

class StartProcOrThread(Callback):
    """
    Start some threads or processes before training, and optionally
    stop them after training.
    """

    # NOTE(review): presumably tells the base ``Callback`` machinery that this
    # callback should also run on non-chief workers — confirm against ``Callback``.
    _chief_only = False

    # Seconds to wait in ``join`` for each stopped process/thread to exit.
    _join_timeout = 5.0

    def __init__(self, startable, stop_at_last=True):
        """
        Args:
            startable (list): list of processes or threads which have ``start()`` method.
                Can also be a tuple of them, or a single instance of process or thread.
            stop_at_last (bool): whether to stop the processes or threads
                after training. It will use :meth:`Process.terminate()` or
                :meth:`StoppableThread.stop()`, but will do nothing on normal
                ``threading.Thread`` or other startable objects.
        """
        if isinstance(startable, tuple):
            # A tuple of startables is a collection, not a single startable.
            startable = list(startable)
        elif not isinstance(startable, list):
            startable = [startable]
        # A list argument is stored as-is (same object, not a copy).
        self._procs_threads = startable
        self._stop_at_last = stop_at_last

    @staticmethod
    def _display_name(obj):
        """Return ``obj.name`` if present (processes/threads have one), else ``str(obj)``."""
        return getattr(obj, 'name', str(obj))

    def _before_train(self):
        """Log and start every registered process/thread."""
        names = ', '.join(self._display_name(k) for k in self._procs_threads)
        logger.info("Starting " + names + ' ...')
        # avoid SIGINT being handled by the started child processes
        start_proc_mask_signal(self._procs_threads)

    def _after_train(self):
        """
        If ``stop_at_last`` was set, stop every still-running ``mp.Process``
        (via ``terminate``) and ``StoppableThread`` (via ``stop``), then wait
        up to ``_join_timeout`` seconds for each one to exit, logging an error
        if it is still alive afterwards. Other startables are left untouched.
        """
        if not self._stop_at_last:
            return
        for k in self._procs_threads:
            # Decide the type before calling ``is_alive``: arbitrary startables
            # may not define it, and they are documented to be left alone.
            if isinstance(k, mp.Process):
                kind, stop = 'process', k.terminate
            elif isinstance(k, StoppableThread):
                kind, stop = 'thread', k.stop
            else:
                continue
            if not k.is_alive():
                continue
            name = self._display_name(k)
            logger.info("Stopping {} ...".format(name))
            stop()
            k.join(self._join_timeout)
            if k.is_alive():
                logger.error("Cannot join {} {}.".format(kind, name))