Commit 15054114 authored by Stefano Alberto Russo's avatar Stefano Alberto Russo
Browse files

Fixed wrong tunnel setup for ssh-based computing resources.

parent 4fa285b7
Loading
Loading
Loading
Loading
+63 −57
Original line number Diff line number Diff line
@@ -32,9 +32,6 @@ color_map = ["#440154", "#440558", "#450a5c", "#450e60", "#451465", "#461969",
             "#97d73e", "#9ed93a", "#a8db34", "#b0dd31", "#b8de30", "#c3df2e",
             "#cbe02d", "#d6e22b", "#e1e329", "#eae428", "#f5e626", "#fde725"]

#======================
#  Utility functions
#======================

def booleanize(*args, **kwargs):
    # Handle both single value and kwargs to get arg name
@@ -265,10 +262,6 @@ def get_md5(string):
    return md5


#=========================
#   Time 
#=========================

def timezonize(timezone):
    '''Convert a string representation of a timezone to its pytz object or do nothing if the argument is already a pytz timezone'''
    
@@ -283,14 +276,17 @@ def timezonize(timezone):
        timezone = pytz.timezone(timezone)
    return timezone


def now_t():
    '''Return the current time in epoch seconds'''
    return now_s()


def now_s():
    '''Return the current time in epoch seconds'''
    return calendar.timegm(now_dt().utctimetuple())


def now_dt(tzinfo='UTC'):
    '''Return the current time in datetime format'''
    if tzinfo != 'UTC':
@@ -335,10 +331,12 @@ def dt(*args, **kwargs):

    return  time_dt


def get_tz_offset_s(time_dt):
    '''Get the time zone offset in seconds'''
    return s_from_dt(time_dt.replace(tzinfo=pytz.UTC)) - s_from_dt(time_dt)


def check_dt_consistency(date_dt):
    '''Check that the timezone is consistent with the datetime (some conditions in Python lead to have summertime set in winter)'''

@@ -355,6 +353,7 @@ def check_dt_consistency(date_dt):
        else:
            return True


def correct_dt_dst(datetime_obj):
    '''Check that the dst is correct and if not change it'''

@@ -374,14 +373,17 @@ def correct_dt_dst(datetime_obj):
              datetime_obj.microsecond,
              tzinfo=datetime_obj.tzinfo)


def change_tz(dt, tz):
    return dt.astimezone(timezonize(tz))


def dt_from_t(timestamp_s, tz=None):
    '''Create a datetime object from an epoch timestamp in seconds. If no timezone is given, UTC is assumed'''
    # TODO: check if uniform everything on this one or not.
    return dt_from_s(timestamp_s=timestamp_s, tz=tz)


def dt_from_s(timestamp_s, tz=None):
    '''Create a datetime object from an epoch timestamp in seconds. If no timezone is given, UTC is assumed'''

@@ -397,6 +399,7 @@ def dt_from_s(timestamp_s, tz=None):
    
    return timestamp_dt


def s_from_dt(dt):
    '''Returns seconds with floating point for milliseconds/microseconds.'''
    if not (isinstance(dt, datetime.datetime)):
@@ -404,6 +407,7 @@ def s_from_dt(dt):
    microseconds_part = (dt.microsecond/1000000.0) if dt.microsecond else 0
    return  ( calendar.timegm(dt.utctimetuple()) + microseconds_part)


def dt_from_str(string, timezone=None):

    # Supported formats on UTC
@@ -458,10 +462,12 @@ def dt_from_str(string, timezone=None):
    
    return dt(year, month, day, hour, minute, second, usecond, offset_s=offset_s)


def dt_to_str(dt):
    '''Return the ISO representation of the datetime as argument'''
    return dt.isoformat()


class dt_range(object):

    def __init__(self, from_dt, to_dt, timeSlotSpan):
@@ -489,20 +495,18 @@ class dt_range(object):
        return self.__next__()


#================================
#  Others
#================================

def debug_param(**kwargs):
    for item in kwargs:
        logger.critical('Param "{}": "{}"'.format(item, kwargs[item]))


def get_my_ip():
    import socket
    hostname = socket.gethostname()
    my_ip = socket.gethostbyname(hostname)
    return my_ip


def get_webapp_conn_string():
    webapp_ssl  = booleanize(os.environ.get('ROSETTA_WEBAPP_SSL', False))
    webapp_host = os.environ.get('ROSETTA_WEBAPP_HOST', get_my_ip())
@@ -513,32 +517,68 @@ def get_webapp_conn_string():
        webapp_conn_string = 'http://{}:{}'.format(webapp_host, webapp_port)
    return webapp_conn_string


def get_platform_registry():
    platform_registry_host = os.environ.get('PLATFORM_REGISTRY_HOST', 'proxy')
    platform_registry_port = os.environ.get('PLATFORM_REGISTRY_PORT', '5000')
    platform_registry_conn_string = '{}:{}'.format(platform_registry_host, platform_registry_port)
    return platform_registry_conn_string

  
def get_rosetta_tasks_tunnel_host():
    # Importing here instead of on top avoids circular dependencies problems when loading booleanize in settings
    from django.conf import settings
    tunnel_host = os.environ.get('ROSETTA_TASKS_TUNNEL_HOST', settings.ROSETTA_HOST)
    return tunnel_host


def get_rosetta_tasks_proxy_host():
    # Importing here instead of on top avoids circular dependencies problems when loading booleanize in settings
    from django.conf import settings
    proxy_host = os.environ.get('ROSETTA_TASKS_PROXY_HOST', settings.ROSETTA_HOST)
    return proxy_host


def hash_string_to_int(string):
    return int(hashlib.sha1(string.encode('utf8')).hexdigest(), 16)


def get_ssh_access_mode_credentials(computing, user):
    
    from .models import KeyPair
    
    # Get computing host
    try:
        computing_host = computing.conf.get('host')
    except AttributeError:
        computing_host = None
    if not computing_host:
        raise ValueError('No computing host?!')

    # Get computing (SSH) port
    try:
        computing_port = computing.conf.get('port')
    except AttributeError:
        computing_port = 22
    if not computing_host:
        computing_port = 22
      
    # Get computing user and keys
    if computing.auth_mode == 'user_keys':
        computing_user = user.profile.get_extra_conf('computing_user', computing)
        if not computing_user:
            raise ValueError('No \'computing_user\' parameter found for computing resource \'{}\' in user profile'.format(computing.name))
        # Get user key
        computing_keys = KeyPair.objects.get(user=user, default=True)
    elif computing.auth_mode == 'platform_keys':        
        computing_user = computing.conf.get('user')
        computing_keys = KeyPair.objects.get(user=None, default=True)
    else:
        raise NotImplementedError('Auth modes other than user_keys and platform_keys not supported.')
    if not computing_user:
            raise ValueError('No \'user\' parameter found for computing resource \'{}\' in its configuration'.format(computing.name))
    return (computing_user, computing_host, computing_port, computing_keys)

#================================
#  Tunnel (and proxy) setup
#================================

def setup_tunnel_and_proxy(task):

@@ -601,6 +641,12 @@ def setup_tunnel_and_proxy(task):
                     
            tunnel_command= 'ssh -4 -i {} -o StrictHostKeyChecking=no -nNT -L 0.0.0.0:{}:{}:{} {}@{} & '.format(user_keys.private_key_file, task.tcp_tunnel_port, task.interface_ip, task.interface_port, first_user, first_host)

        else:
            
            if task.computing.access_mode.startswith('ssh'):
                computing_user, computing_host, computing_port, computing_keys = get_ssh_access_mode_credentials(task.computing, task.user)
                tunnel_command  = 'ssh -p {} -o LogLevel=ERROR -i {} -4 -o StrictHostKeyChecking=no -o ConnectTimeout=10 '.format(computing_port, computing_keys.private_key_file)
                tunnel_command += '-nNT -L 0.0.0.0:{}:{}:{} {}@{}'.format(task.tcp_tunnel_port, task.interface_ip, task.interface_port, computing_user, computing_host)
            else:
                tunnel_command= 'ssh -4 -o StrictHostKeyChecking=no -nNT -L 0.0.0.0:{}:{}:{} localhost & '.format(task.tcp_tunnel_port, task.interface_ip, task.interface_port)
        
@@ -713,46 +759,6 @@ Listen '''+str(task.tcp_tunnel_port)+'''
                raise ErrorMessage('Something went wrong when loading the task proxy conf')        
            



def get_ssh_access_mode_credentials(computing, user):
    
    from .models import KeyPair
    
    # Get computing host
    try:
        computing_host = computing.conf.get('host')
    except AttributeError:
        computing_host = None
    if not computing_host:
        raise ValueError('No computing host?!')

    # Get computing (SSH) port
    try:
        computing_port = computing.conf.get('port')
    except AttributeError:
        computing_port = 22
    if not computing_host:
        computing_port = 22
      
    # Get computing user and keys
    if computing.auth_mode == 'user_keys':
        computing_user = user.profile.get_extra_conf('computing_user', computing)
        if not computing_user:
            raise ValueError('No \'computing_user\' parameter found for computing resource \'{}\' in user profile'.format(computing.name))
        # Get user key
        computing_keys = KeyPair.objects.get(user=user, default=True)
    elif computing.auth_mode == 'platform_keys':        
        computing_user = computing.conf.get('user')
        computing_keys = KeyPair.objects.get(user=None, default=True)
    else:
        raise NotImplementedError('Auth modes other than user_keys and platform_keys not supported.')
    if not computing_user:
            raise ValueError('No \'user\' parameter found for computing resource \'{}\' in its configuration'.format(computing.name))
    return (computing_user, computing_host, computing_port, computing_keys)



def sanitize_container_env_vars(env_vars):
    
    for env_var in env_vars: