Permalink
Please
sign in to comment.
Browse files
Support for OAuth based SSH Channels (#944)
* Updating globus_ssh requirements * Updating globus_ssh requirements * Adding support for an OAuth based SSH Channel * Minor cleanups, and revising nbytes to 1024 * Adding an oauth channel test * Replacing globus-ssh requirement with the oauth-ssh module * Updating version req for oauth-ssh to 0.9 * Fixing based on review comments from Ben * Adding new channel to init * Adding oauth-ssh to mypy exclude list * Removing unused modules * Addressing Ben's and Anna's comments
- Loading branch information...
Showing
with
158 additions
and 2 deletions.
@@ -1,5 +1,6 @@ | |||
from parsl.channels.ssh.ssh import SSHChannel | |||
from parsl.channels.local.local import LocalChannel | |||
from parsl.channels.ssh_il.ssh_il import SSHInteractiveLoginChannel | |||
from parsl.channels.oauth_ssh.oauth_ssh import OAuthSSHChannel | |||
|
|||
__all__ = ['SSHChannel', 'LocalChannel', 'SSHInteractiveLoginChannel'] | |||
__all__ = ['SSHChannel', 'LocalChannel', 'SSHInteractiveLoginChannel', 'OAuthSSHChannel'] |
No changes.
@@ -0,0 +1,137 @@ | |||
import logging | |||
import paramiko | |||
import socket | |||
|
|||
from parsl.errors import OptionalModuleMissing | |||
from parsl.channels.ssh.ssh import SSHChannel | |||
|
|||
try: | |||
from oauth_ssh.ssh_service import SSHService | |||
from oauth_ssh.oauth_ssh_token import find_access_token | |||
_oauth_ssh_enabled = True | |||
except (ImportError, NameError): | |||
_oauth_ssh_enabled = False | |||
|
|||
|
|||
logger = logging.getLogger(__name__) | |||
|
|||
|
|||
class OAuthSSHChannel(SSHChannel): | |||
"""SSH persistent channel. This enables remote execution on sites | |||
accessible via ssh. This channel uses Globus based OAuth tokens for authentication. | |||
""" | |||
|
|||
def __init__(self, hostname, username=None, script_dir=None, envs=None, port=22): | |||
''' Initialize a persistent connection to the remote system. | |||
We should know at this point whether ssh connectivity is possible | |||
Args: | |||
- hostname (String) : Hostname | |||
KWargs: | |||
- username (string) : Username on remote system | |||
- script_dir (string) : Full path to a script dir where | |||
generated scripts could be sent to. | |||
- envs (dict) : A dictionary of env variables to be set when executing commands | |||
- port (int) : Port at which the SSHService is running | |||
Raises: | |||
''' | |||
if not _oauth_ssh_enabled: | |||
raise OptionalModuleMissing(['oauth_ssh'], | |||
"OauthSSHChannel requires oauth_ssh module and config.") | |||
|
|||
self.hostname = hostname | |||
self.username = username | |||
self.script_dir = script_dir | |||
|
|||
self.envs = {} | |||
if envs is not None: | |||
self.envs = envs | |||
|
|||
try: | |||
access_token = find_access_token(hostname) | |||
except Exception: | |||
logger.exception("Failed to find the access token for {}".format(hostname)) | |||
raise | |||
|
|||
try: | |||
self.service = SSHService(hostname, port) | |||
self.transport = self.service.login(access_token, username) | |||
|
|||
except Exception: | |||
logger.exception("Caught an exception in the OAuth authentication step with {}".format(hostname)) | |||
raise | |||
|
|||
self.sftp_client = paramiko.SFTPClient.from_transport(self.transport) | |||
|
|||
def execute_wait(self, cmd, walltime=60, envs={}): | |||
''' Synchronously execute a commandline string on the shell. | |||
This command does *NOT* honor walltime currently. | |||
Args: | |||
- cmd (string) : Commandline string to execute | |||
Kwargs: | |||
- walltime (int) : walltime in seconds | |||
- envs (dict) : Dictionary of env variables | |||
Returns: | |||
- retcode : Return code from the execution, -1 on fail | |||
- stdout : stdout string | |||
- stderr : stderr string | |||
Raises: | |||
None. | |||
''' | |||
|
|||
session = self.transport.open_session() | |||
session.setblocking(0) | |||
|
|||
nbytes = 1024 | |||
session.exec_command(self.prepend_envs(cmd, envs)) | |||
session.settimeout(walltime) | |||
|
|||
try: | |||
# Wait until command is executed | |||
exit_status = session.recv_exit_status() | |||
|
|||
stdout = session.recv(nbytes).decode('utf-8') | |||
stderr = session.recv_stderr(nbytes).decode('utf-8') | |||
|
|||
except socket.timeout: | |||
logger.exception("Command failed to execute without timeout limit on {}".format(self)) | |||
raise | |||
|
|||
return exit_status, stdout, stderr | |||
|
|||
def execute_no_wait(self, cmd, walltime=60, envs={}): | |||
''' Execute asynchronousely without waiting for exitcode | |||
Args: | |||
- cmd (string): Commandline string to be executed on the remote side | |||
KWargs: | |||
- walltime (int): timeout to exec_command | |||
- envs (dict): A dictionary of env variables | |||
Returns: | |||
- None, stdout (readable stream), stderr (readable stream) | |||
Raises: | |||
- ChannelExecFailed (reason) | |||
''' | |||
session = self.transport.open_session() | |||
session.setblocking(0) | |||
|
|||
nbytes = 10240 | |||
session.exec_command(self.prepend_envs(cmd, envs)) | |||
|
|||
stdout = session.recv(nbytes).decode('utf-8') | |||
stderr = session.recv_stderr(nbytes).decode('utf-8') | |||
|
|||
return None, stdout, stderr | |||
|
|||
def close(self): | |||
return self.transport.close() |
@@ -0,0 +1,13 @@ | |||
from parsl.channels import OAuthSSHChannel | |||
|
|||
|
|||
def test_channel(): | |||
channel = OAuthSSHChannel(hostname='ssh.demo.globus.org', username='yadunand') | |||
x, stdout, stderr = channel.execute_wait('ls') | |||
print(x, stdout, stderr) | |||
assert x == 0, "Expected exit code 0, got {}".format(x) | |||
|
|||
|
|||
if __name__ == '__main__': | |||
|
|||
test_channel() |
0 comments on commit
0e7921d