from __future__ import print_function, division, absolute_import
import logging
import socket
import os
import sys
import time
import traceback
try:
    from queue import Queue
except ImportError:  # Python 2 fallback
    from Queue import Queue
from threading import Thread
from toolz import merge
from tornado import gen
logger = logging.getLogger(__name__)
# These are handy for creating colorful terminal output to enhance readability
# of the output generated by dask-ssh.
class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
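
# Example usage (illustrative): wrap a message in a color code and reset it
# with ENDC, e.g. print(bcolors.FAIL + 'connection lost' + bcolors.ENDC).
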
def async_ssh(cmd_dict):
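    """ Run ``cmd_dict['cmd']`` on ``cmd_dict['address']`` over SSH.

    ``cmd_dict`` supplies the connection settings ('address', 'ssh_username',
    'ssh_port', 'ssh_private_key'), the command and its display 'label', and
    two queues: each line of remote output is pushed onto 'output_queue', and
    any message appearing on 'input_queue' tells this thread to stop the
    remote process and shut down.
    """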
    import paramiko
    from paramiko.buffered_pipe import PipeTimeout
    from paramiko.ssh_exception import (SSHException, PasswordRequiredException)

    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
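    # Note: AutoAddPolicy accepts unknown host keys without prompting, which is
    # convenient for cluster deployment but skips host-key verification.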
    retries = 0
    while True:  # Be robust to transient SSH failures.
        try:
            # Set paramiko logging to WARN or higher to squelch INFO messages.
            logging.getLogger('paramiko').setLevel(logging.WARN)

            ssh.connect(hostname=cmd_dict['address'],
                        username=cmd_dict['ssh_username'],
                        port=cmd_dict['ssh_port'],
                        key_filename=cmd_dict['ssh_private_key'],
                        compress=True,
                        timeout=20,
                        banner_timeout=20)  # Helps prevent timeouts when many
                                            # concurrent SSH connections are opened.

            # Connection successful; break out of the retry loop.
            break
        except (SSHException, PasswordRequiredException) as e:
            print('[ dask-ssh ] : ' + bcolors.FAIL +
                  'SSH connection error when connecting to {addr}:{port} '
                  'to run \'{cmd}\''.format(addr=cmd_dict['address'],
                                            port=cmd_dict['ssh_port'],
                                            cmd=cmd_dict['cmd']) +
                  bcolors.ENDC)
            print(bcolors.FAIL + '    SSH reported this exception: ' + str(e) + bcolors.ENDC)

            # Print an exception traceback
            traceback.print_exc()

            # Transient SSH errors can occur when many SSH connections are
            # simultaneously opened to the same server.  This makes a few
            # attempts to retry.
            retries += 1
            if retries >= 3:
                print('[ dask-ssh ] : ' + bcolors.FAIL +
                      'SSH connection failed after 3 retries. Exiting.' +
                      bcolors.ENDC)

                # Connection failed after multiple attempts.  Terminate this thread.
                os._exit(1)

            # Wait a moment before retrying
            print('    ' + bcolors.FAIL +
                  'Retrying... (attempt {n}/{total})'.format(n=retries, total=3) +
                  bcolors.ENDC)
            time.sleep(1)
    # Execute the command, and grab file handles for stdout and stderr.  Note
    # that we run the command using the user's default shell, but force it to
    # run in an interactive login shell, which hopefully ensures that all of the
    # user's normal environment variables (via the dot files) have been loaded
    # before the command is run.  This should help to ensure that important
    # aspects of the environment like PATH and PYTHONPATH are configured.
    print('[ {label} ] : {cmd}'.format(label=cmd_dict['label'],
                                       cmd=cmd_dict['cmd']))
    stdin, stdout, stderr = ssh.exec_command('$SHELL -i -c \'' + cmd_dict['cmd'] + '\'',
                                             get_pty=True)

    # Set up a channel timeout (which we rely on below to make readline()
    # non-blocking).
    channel = stdout.channel
    channel.settimeout(0.1)
    def read_from_stdout():
        """
        Read stdout stream, time out if necessary.
        """
        try:
            line = stdout.readline()
            while len(line) > 0:  # Loops until a timeout exception occurs
                line = line.rstrip()
                logger.debug('stdout from ssh channel: %s', line)
                cmd_dict['output_queue'].put('[ {label} ] : {output}'.format(
                    label=cmd_dict['label'], output=line))
                line = stdout.readline()
        except (PipeTimeout, socket.timeout):
            pass
    def read_from_stderr():
        """
        Read stderr stream, time out if necessary.
        """
        try:
            line = stderr.readline()
            while len(line) > 0:
                line = line.rstrip()
                logger.debug('stderr from ssh channel: %s', line)
                cmd_dict['output_queue'].put('[ {label} ] : '.format(label=cmd_dict['label']) +
                                             bcolors.FAIL + line + bcolors.ENDC)
                line = stderr.readline()
        except (PipeTimeout, socket.timeout):
            pass
    def communicate():
        """
        Communicate a little bit, without blocking too long.
        Return True if the command has ended.
        """
        read_from_stdout()
        read_from_stderr()

        # Check to see if the process has exited.  If it has, we let this
        # thread terminate.
        if channel.exit_status_ready():
            exit_status = channel.recv_exit_status()
            cmd_dict['output_queue'].put('[ {label} ] : '.format(label=cmd_dict['label']) +
                                         bcolors.FAIL +
                                         'remote process exited with exit status ' +
                                         str(exit_status) + bcolors.ENDC)
            return True
    # Wait for a message on the input_queue.  Any message received signals this
    # thread to shut itself down.
    while cmd_dict['input_queue'].empty():
        # Kill some time so that this thread does not hog the CPU.
        time.sleep(1.0)
        if communicate():
            break

    # Ctrl-C the executing command and wait a bit for it to end cleanly
    start = time.time()
    while time.time() < start + 5.0:
        channel.send(b'\x03')  # Ctrl-C
        if communicate():
            break
        time.sleep(1.0)

    # Shut down the channel, and close the SSH connection
    channel.close()
    ssh.close()

def start_scheduler(logdir, addr, port, ssh_username, ssh_port, ssh_private_key, remote_python=None):
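    """ Launch a distributed scheduler on ``addr`` via SSH.

    Builds the ``dask-scheduler`` command below, hands it to ``async_ssh`` in
    a daemon thread, and returns the command dictionary with the thread merged
    in under the 'thread' key.
    """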
    cmd = '{python} -m distributed.cli.dask_scheduler --port {port}'.format(
        python=remote_python or sys.executable, port=port)

    # Optionally redirect stdout and stderr to a logfile
    if logdir is not None:
        cmd = 'mkdir -p {logdir} && '.format(logdir=logdir) + cmd
        cmd += ' &> {logdir}/dask_scheduler_{addr}:{port}.log'.format(
            addr=addr, port=port, logdir=logdir)
    # Format an output label that we can prepend to each line of output.
    label = (bcolors.BOLD +
             'scheduler {addr}:{port}'.format(addr=addr, port=port) +
             bcolors.ENDC)

    # Create a command dictionary, which contains everything we need to run and
    # interact with this command.
    input_queue = Queue()
    output_queue = Queue()
    cmd_dict = {'cmd': cmd, 'label': label, 'address': addr, 'port': port,
                'input_queue': input_queue, 'output_queue': output_queue,
                'ssh_username': ssh_username, 'ssh_port': ssh_port,
                'ssh_private_key': ssh_private_key}
    # Start the thread
    thread = Thread(target=async_ssh, args=[cmd_dict])
    thread.daemon = True
    thread.start()

    return merge(cmd_dict, {'thread': thread})

def start_worker(logdir, scheduler_addr, scheduler_port, worker_addr, nthreads, nprocs,
                 ssh_username, ssh_port, ssh_private_key, nohost,
                 memory_limit,
                 worker_port,
                 nanny_port,
                 remote_python=None):
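    """ Launch a distributed worker on ``worker_addr`` via SSH, pointing it at
    the scheduler at ``scheduler_addr:scheduler_port``.

    Returns the command dictionary with the monitoring thread merged in under
    the 'thread' key.
    """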
    cmd = ('{python} -m distributed.cli.dask_worker '
           '{scheduler_addr}:{scheduler_port} '
           '--nthreads {nthreads} --nprocs {nprocs} ')

    if not nohost:
        cmd += ' --host {worker_addr} '

    if memory_limit:
        cmd += '--memory-limit {memory_limit} '

    if worker_port:
        cmd += '--worker-port {worker_port} '

    if nanny_port:
        cmd += '--nanny-port {nanny_port} '

    cmd = cmd.format(python=remote_python or sys.executable,
                     scheduler_addr=scheduler_addr,
                     scheduler_port=scheduler_port,
                     worker_addr=worker_addr,
                     nthreads=nthreads,
                     nprocs=nprocs,
                     memory_limit=memory_limit,
                     worker_port=worker_port,
                     nanny_port=nanny_port)

    # Optionally redirect stdout and stderr to a logfile
    if logdir is not None:
        cmd = 'mkdir -p {logdir} && '.format(logdir=logdir) + cmd
        cmd += '&> {logdir}/dask_worker_{addr}.log'.format(
            addr=worker_addr, logdir=logdir)
    label = 'worker {addr}'.format(addr=worker_addr)

    # Create a command dictionary, which contains everything we need to run and
    # interact with this command.
    input_queue = Queue()
    output_queue = Queue()
    cmd_dict = {'cmd': cmd, 'label': label, 'address': worker_addr,
                'input_queue': input_queue, 'output_queue': output_queue,
                'ssh_username': ssh_username, 'ssh_port': ssh_port,
                'ssh_private_key': ssh_private_key}

    # Start the thread
    thread = Thread(target=async_ssh, args=[cmd_dict])
    thread.daemon = True
    thread.start()

    return merge(cmd_dict, {'thread': thread})

class SSHCluster(object):
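    """ Deploy a dask scheduler and a set of dask workers over SSH.

    Example (illustrative sketch; the host addresses are hypothetical and
    key-based SSH access to every host is assumed):

    >>> cluster = SSHCluster('192.168.1.100', 8786,          # doctest: +SKIP
    ...                      ['192.168.1.101', '192.168.1.102'])
    >>> cluster.monitor_remote_processes()                   # doctest: +SKIP
    >>> cluster.shutdown()                                   # doctest: +SKIP
    """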
    def __init__(self, scheduler_addr, scheduler_port, worker_addrs, nthreads=0, nprocs=1,
                 ssh_username=None, ssh_port=22, ssh_private_key=None,
                 nohost=False, logdir=None, remote_python=None,
                 memory_limit=None, worker_port=None, nanny_port=None):
        self.scheduler_addr = scheduler_addr
        self.scheduler_port = scheduler_port
        self.nthreads = nthreads
        self.nprocs = nprocs

        self.ssh_username = ssh_username
        self.ssh_port = ssh_port
        self.ssh_private_key = ssh_private_key
        self.nohost = nohost
        self.remote_python = remote_python

        self.memory_limit = memory_limit
        self.worker_port = worker_port
        self.nanny_port = nanny_port

        # Generate a universal timestamp to use for log files
        import datetime
        if logdir is not None:
            logdir = os.path.join(logdir, "dask-ssh_" + datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"))
            print(bcolors.WARNING +
                  'Output will be redirected to logfiles stored locally on '
                  'individual worker nodes under "{logdir}".'.format(logdir=logdir) +
                  bcolors.ENDC)
        self.logdir = logdir

        # Keep track of all running threads
        self.threads = []

        # Start the scheduler node
        self.scheduler = start_scheduler(logdir, scheduler_addr, scheduler_port,
                                         ssh_username, ssh_port, ssh_private_key,
                                         remote_python)

        # Start worker nodes
        self.workers = []
        for addr in worker_addrs:
            self.add_worker(addr)
    @gen.coroutine
    def _start(self):
        pass
    @property
    def scheduler_address(self):
        return '%s:%d' % (self.scheduler_addr, self.scheduler_port)
    def monitor_remote_processes(self):
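        """ Print output from all remote processes until interrupted.

        Sweeps the output queues of the scheduler and every worker; a
        KeyboardInterrupt (Ctrl-C) returns control to the caller.
        """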
        # Form a list containing all processes, since we treat them equally
        # from here on out.
        all_processes = [self.scheduler] + self.workers

        try:
            while True:
                for process in all_processes:
                    while not process['output_queue'].empty():
                        print(process['output_queue'].get())

                # Kill some time and free up CPU before starting the next sweep
                # through the processes.
                time.sleep(0.1)
        except KeyboardInterrupt:
            pass  # Return execution to the calling process
    def add_worker(self, address):
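        """ Start an additional worker on ``address`` and track it in
        ``self.workers``. """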
        self.workers.append(start_worker(self.logdir, self.scheduler_addr,
                                         self.scheduler_port, address,
                                         self.nthreads, self.nprocs,
                                         self.ssh_username, self.ssh_port,
                                         self.ssh_private_key, self.nohost,
                                         self.memory_limit,
                                         self.worker_port,
                                         self.nanny_port,
                                         self.remote_python))
    def shutdown(self):
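        """ Signal every remote process to shut down and wait for its
        monitoring thread to exit. """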
        all_processes = [self.scheduler] + self.workers

        for process in all_processes:
            process['input_queue'].put('shutdown')
            process['thread'].join()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.shutdown()