#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2013 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

from winswitch.util.simple_logger import Logger, msig
logger = Logger("server_link")
debug_import = logger.get_debug_import()

debug_import("sys, time")
import sys
import time


debug_import("server_receiver")
from winswitch.client.server_receiver import ServerReceiverConnectionFactory
debug_import("fake_embedded_connector")
from winswitch.client.fake_embedded_connector import EmbeddedServerChannelFactory
debug_import("conch_server_channel")
from winswitch.client.conch_server_channel import ConchFixedPortLinkFactory, ConchStdioRemoteClientLinkFactory
debug_import("consts")
from winswitch.consts import LOCALHOST, X11_TYPE, GSTVIDEO_TYPE, SCREEN_TYPE, NOTIFY_ERROR
debug_import("globals")
from winswitch.globals import USER_ID, WIN32
debug_import("net_util")
from winswitch.net.net_util import get_port_mapper
debug_import("global_settings")
from winswitch.objects.global_settings import get_settings
debug_import("session")
from winswitch.objects.session import Session
debug_import("server_config")
from winswitch.objects.server_config import ServerConfig
debug_import("file_io")
from winswitch.util.file_io import get_local_server_socket
debug_import("main_loop")
from winswitch.util.main_loop import connectUNIX, connectTCP, callLater




#presumably the maximum number of connection retries - TODO confirm where this is used
MAX_RETRY = 3

TEST_TUNNELLED_SOCKET = False		#unused
SEND_MOUNT_POINTS_TO_LOCAL = False	#send mount points to a local server (which could end up mounting itself!)
#enabled with --force-send-mountpoints on the command line: tunnel_ports_changed() then always sends the mount points
FORCE_SEND_MOUNT_POINTS = "--force-send-mountpoints" in sys.argv		#useful for testing only

#useful for testing
IGNORE_REMOTE_ICONS = False
IGNORE_LOCAL_ICONS = False

#module-wide port allocator: its get_free_command_port() supplies local ports for new port forwards when none is given
port_mapper = get_port_mapper()


class KeepPort:
	"""
	Wraps a port allocation function so that it is only ever invoked once:
	the port obtained by get_port() is remembered and can be retrieved
	afterwards with get_assigned_port().
	"""
	def __init__(self, get_free_port):
		self.get_free_port = get_free_port		#callable returning a free port number
		self.port = 0							#0 until get_port() has been called

	def get_port(self):
		"""Allocate the port (this may only happen once) and return it."""
		assert self.port == 0			#a second allocation is a programming error
		allocated = self.get_free_port()
		self.port = allocated
		return allocated

	def get_assigned_port(self):
		"""Return the port allocated earlier by get_port()."""
		assert self.port != 0			#get_port() must have been called first
		return self.port



class ServerLink:
	"""
	This class provides methods for starting/maintaining/stopping the connection to a server.
	"""

	def __init__(self, server_config, notify=None, dialog_util=None, session_detach=None, session_attach=None):
		"""
		Builds the link for one server configuration.
		Everything UI related is injected by the caller, so that this class
		does not depend on any UI code:
		 - notify: callable used by notify() to show notifications
		 - dialog_util: object whose ask() / cancel_ask() are used to ask questions
		 - session_attach: callable(session, host, port) used to attach to a session
		 - session_detach: the corresponding detach callback (stored as-is)
		"""
		Logger(self)
		sig = msig(server_config, notify, dialog_util, session_detach, session_attach)
		#caller supplied configuration and hooks:
		self.server = server_config
		self.notify_function = notify
		self.dialog_util = dialog_util
		self.session_detach = session_detach
		self.session_attach = session_attach
		self.settings = get_settings()

		#connection state:
		self.client_factory = None			#a ServerReceiverConnectionFactory or one of the conch link factories
		self.client = None					#ServerLineConnection instance (see set_client())
		self.conch_connection = None		#used by conch either for the client channel and tunnels or just for the tunnels
		self.forwarders = {}				#target spec -> port forwarder (see add_port_forward())

		#flags:
		self.stop_requested = False			#set by stop(), also silences notify()
		self.warned = False
		self.notify_embargo = None			#timestamp before which notify() stays silent
		self.debug(sig + (" auto_resume=%s" % server_config.auto_resume))

	def __str__(self):
		return	"ServerLink(%s)" % str(self.server)

	def notify(self, title, message, delay=None, callback=None, notification_type=None, from_uuid=None):
		"""
		Uses the notification_util (if present) to notify the user about events.
		We disable this feature when stop_requested to prevent stop events from generating notifications.
		"""
		self.sdebug(None, title, message, delay, callback, notification_type, from_uuid)
		if not self.stop_requested and self.notify_function and self.server:
			if self.notify_embargo and self.notify_embargo>time.time():
				self.slog("not sending notification as embargo time not reached", title, message)
			else:
				self.notify_function(title, message, delay=delay, callback=callback, notification_type=notification_type, from_server=self.server, from_uuid=from_uuid)

	def ask(self, title, text, nok_callback, ok_callback, password=False, ask_save_password=False, buttons=None, icon=None, UUID=None):
		"""
		Uses the question_handler (if present) to ask the user questions.
		"""
		if self.dialog_util:
			self.dialog_util.ask(title, text, nok_callback, ok_callback, password=password, ask_save_password=ask_save_password, buttons=buttons, icon=icon, UUID=UUID)
		else:
			self.serror("question handler is not set!", title, text, ok_callback, nok_callback)

	def cancel_ask(self, UUID):
		self.sdebug(None, UUID)
		if self.dialog_util:
			self.dialog_util.cancel_ask(UUID)

	def stop(self):
		"""
		Deliberately disconnects from the server.
		A logout is sent first, then the stop flag is raised (which also
		silences notify()) before the connections are closed and the server
		is marked as disconnected.
		"""
		self.debug()
		self.send_logout()
		self.stop_requested = True
		self.close_connections()
		server = self.server
		server.set_status(ServerConfig.STATUS_DISCONNECTED)
		server.touch()
		self.warned = False


	def send_logout(self):
		self.sdebug("client_factory=%s, client=%s" % (self.client_factory, self.client))
		if self.client_factory:
			self.client_factory.do_retry = False
			self.client_factory.closed = True
		if self.client and self.client.handler:
			try:
				#FIXME: this doesn't get flushed out! (why?)
				self.client.handler.send_logout()
			except Exception, e:
				self.error("() %s" % e)

	def kick(self):
		"""
		Connect or re-connect to the server
		"""
		self.stop_requested = False
		if self.client_factory and not self.server.is_connected():
			self.sdebug("aborting connection to server %s" % self.server)
			self.client_factory.abort = True
			self.client_factory = None
		if self.server.is_connected():
			self.sdebug("already connected to server %s" % self.server)
		else:
			self.sdebug("server=%s" % self.server)
			if self.client:
				self.stop_connection()
			self.connect()

	def set_client(self, client_instance):
		if self.client:
			self.serror("client already exists! closing connections!", client_instance)
			self.close_connections()
		else:
			self.slog(None, client_instance)
		self.client = client_instance

	def set_server_status(self, client_factory, new_status):
		if self.client_factory!=self.client_factory:
			self.slog("ignoring update from old client instance (current client_factory=%s)" % self.client_factory, client_factory, new_status)
			return
		self.server.set_status(new_status)

	def connect(self):
		try:
			self.connect_to_server()
		except Exception, e:
			self.exc(e)

	def connect_to_server(self):
		"""
		Starts a connection attempt to the server, unless it is already
		connected or connecting.
		The server status is set to STATUS_CONNECTING, then the transport is
		chosen from the server configuration:
		 - embedded server: a fake connection to the same process
		 - local server (not on win32, not with --no-local-socket): a UNIX socket
		 - remote server with ssh_tunnel: a conch (SSH) link factory over TCP
		 - anything else: a plain TCP connection to the command port
		For the UNIX socket and plain TCP transports, check_connection_timeout()
		is scheduled after server.timeout; the SSH path does not schedule it here.
		"""
		if self.server.is_connected():
			self.slog("already connected")
			return
		if self.server.is_connecting():
			self.slog("already connecting")
			return
		self.sdebug()
		self.server.set_status(ServerConfig.STATUS_CONNECTING)
		if self.server.embedded_server is not None:
			""" connect via fake connection to the same process """
			self.slog("connecting to embedded server: %s" % self.server.embedded_server)
			self.client_factory = EmbeddedServerChannelFactory(self)
			self.client_factory.connect()
		elif self.server.local and "--no-local-socket" not in sys.argv and not WIN32:
			""" connect to local socket (but not on win32!) """
			#the flag is True when running as root (USER_ID==0) - presumably selects the system-wide socket, TODO confirm
			socket_filename = get_local_server_socket(USER_ID==0)
			self.client_factory = ServerReceiverConnectionFactory(self, "unix:%s" % socket_filename)
			self.sdebug("connecting to UNIX socket: %s" % socket_filename)
			connectUNIX(socket_filename, self.client_factory, timeout=self.server.timeout)
			callLater(self.server.timeout, self.check_connection_timeout)
		elif self.server.ssh_tunnel and not self.server.local:
			#(local servers reach this chain's final TCP branch when on win32 or with --no-local-socket)
			if self.server.command_port_auto or self.server.command_port<=0:
				""" use stdio client: """
				self.client_factory = ConchStdioRemoteClientLinkFactory(self)
			else:
				""" go straight to fixed port """
				self.client_factory = ConchFixedPortLinkFactory(self)
			#the TCP connection goes to the SSH host/port, not to the command port:
			self.sdebug("connecting using %s via SSH on %s:%s" % (self.client_factory, self.server.host, self.server.port))
			connectTCP(self.server.host, self.server.port, self.client_factory, timeout=self.server.timeout)
		else:
			""" connect over tcp - no tunnel """
			self.client_factory = ServerReceiverConnectionFactory(self, "tcp %s:%s" % (self.server.command_host, self.server.command_port))
			self.sdebug("default command_port=%s, command_host=%s command_port=%s, server.host=%s, server.port=%s" % (self.server.default_command_port, self.server.command_host, self.server.command_port, self.server.host, self.server.port))
			#fallbacks: command_port -> default_command_port -> server.port, command_host -> server.host
			host = self.server.command_host
			port = self.server.command_port
			if not port:
				port = self.server.default_command_port
			if not host:
				host = self.server.host
			if port<=0:
				port = self.server.port
			self.sdebug("connecting using %s via TCP on %s:%s" % (self.client_factory, host, port))
			if not host or port<=0:
				#no usable address: tell the user and reset the connection state
				self.notify("Cannot connect to %s" % self.server.get_display_name(),
							"The connection settings for this server are missing!",
							notification_type=NOTIFY_ERROR)
				self.close_connections()
				return
			connectTCP(host, port, self.client_factory, timeout=self.server.timeout)
			callLater(self.server.timeout, self.check_connection_timeout)

	def check_connection_timeout(self):
		self.sdebug("server.timeout=%s" % self.server.timeout)
		#should have reached connected state by now...
		if self.server.is_connected():
			self.sdebug("connection to %s established within %d seconds" % (self.server, self.server.timeout))
		elif not self.server.is_connecting():
			self.sdebug("connection to %s failed: waited %d seconds" % (self.server, self.server.timeout))
		else:
			self.serror("resetting state of client factory for connection to %s" % self.server)
			if self.client_factory:
				self.client_factory.do_retry = False
				self.client_factory.closed = True

	def close_connections(self):
		"""Marks the server as disconnected, then tears down the port forwards and the connection."""
		self.sdebug()
		self.server.set_status(ServerConfig.STATUS_DISCONNECTED)
		#forwards first, then the connection itself:
		for teardown in (self.stop_forwarders, self.stop_connection):
			teardown()

	def stop_connection(self):
		if self.client_factory:
			self.client_factory.closed = True
			self.client_factory = None
		if self.client:
			try:
				self.client.stop(retry=False, message="stopping this connection")
				self.client = None
			except Exception, e:
				try:
					self.serr("error calling %s on %s (%s)" % (self.client.stop, self.client, type(self.client)), e)
				except Exception, e:
					self.sexc(e)

	def stop_forwarders(self):
		""" Stops all the port forwards """
		self.sdebug("forwarders %s" % str(self.forwarders))
		for forwarder in self.forwarders.values():
			try:
				self.sdebug("stopping forwarder: %s" % str(forwarder))
				#TODO: forwarder.stop()
			except Exception, e:
				self.exc(e)
		self.forwarders = {}



	def get_target_spec(self, from_port, remote_host, remote_port, reverse):
		return	"%s:%d#%s;" % (remote_host, remote_port,reverse)

	def get_existing_port_forward(self, host_spec):
		forward = self.forwarders.get(host_spec)
		"""if forward and forward.terminated:
			del self.forwarders[host_spec]
			return	None
		if forward and forward.stopping:
			return	None"""
		return forward

	def add_session_port_forward(self, session, from_port, remote_port, reverse=False):
		remote_host = session.host
		#FIXME: IPv6
		if remote_host=="" or remote_host=="0.0.0.0":
			remote_host = "127.0.0.1"		#cant use 0.0.0.0 as a target!
		fwd = self.add_port_forward(from_port, remote_host, remote_port, reverse)
		if fwd not in session.tunnels:
			session.tunnels.append(fwd)
		return	fwd

	def add_port_forward(self, from_port, remote_host, remote_port, reverse=False):
		"""
		from_port may be None in which case it will be assigned and returned
		Returns the actual from_port used.
		"""
		if self.conch_connection==None:
			self.serror("no conch connection - cannot forward ports!", from_port, remote_host, remote_port)
			return
		host_spec = self.get_target_spec(from_port, remote_host, remote_port, reverse)
		forwarder = self.get_existing_port_forward(host_spec)
		if forwarder:
			self.sdebug("found existing forward: %s" % str(forwarder), from_port, remote_host, remote_port)
		else:
			self.slog("adding new port forward", from_port, remote_host, remote_port)
			if not from_port or from_port==-1:
				from_port = port_mapper.get_free_command_port()
			def forwarding_ready(*args):
				self.slog(None, *args)
			def forwarding_failed(*args):
				self.serror(None, *args)
				err = ""
				try:
					msg,_ = args[0]
					err = "\nThe error message was: '%s'." % msg
				except:
					pass
				self.notify("SSH Port Forward Failed",
						"Please ensure that the server %s allows SSH port forwarding.%s" % (self.server.get_display_name(), err),
						notification_type=NOTIFY_ERROR)
			forwarder = self.conch_connection.forward_port(from_port, remote_host, remote_port, forwarding_ready, forwarding_failed)
			self.forwarders[host_spec] = forwarder
		self.sdebug("forwarder=%s" % str(forwarder), from_port, remote_host, remote_port)
		return forwarder

	def reverse_port_forward(self, from_port, remote_host, remote_port):
		"""
		Returns a deferred.
		"""
		sig = msig(from_port, remote_host, remote_port)
		if self.conch_connection==None:
			self.debug(sig+" no conch connection")
			pass
		host_spec = self.get_target_spec(from_port, remote_host, remote_port, True)
		forwarder = self.get_existing_port_forward(host_spec)
		if forwarder:
			pass
		self.debug(sig+" adding new reverse port forward")
		if not from_port:
			from_port = port_mapper.get_free_command_port()
		return	self.conch_connection.reverse_forward_port(from_port, remote_host, remote_port)



	def resume_sessions(self):
		"""Requests access (with retry) to every session of this server that is idle or available."""
		self.log()
		resumable = (Session.STATUS_IDLE, Session.STATUS_AVAILABLE)
		for session in self.server.get_sessions().values():
			self.sdebug("testing %s: %s" % (session, session.status))
			if session.status in resumable:
				self.connect_to_session(session, retry=True)

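	#prepares the session's ports without connecting (a tunnel is only set up here when preloading tunnels),
	#then requests access to the session: the server's response triggers the actual connection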
	def connect_to_session(self, session, retry=False):
		self.sdebug("requesting access", session, retry)
		self.prepare_session_ports(session, False)
		if self.client and self.client.handler:
			""" the response to this request (with the password) will trigger the actual connection """
			self.client.handler.send_request_session(session.ID)
		else:
			self.notify("Not connected!",
					"Not connected to server %s, cannot connect to this session." % self.server.get_display_name(),
					notification_type=NOTIFY_ERROR)
			self.serror("not connected to server! cannot request password!", session, retry)




	def prepare_session_ports(self, session, may_connect=False):
		"""
		Prepares the ports of a session we know how to handle, then delegates to
		do_prepare_session_port() (which connects only if may_connect is True).
		Screen sessions need no setup at all. X11 sessions only get their ports
		set up: we never connect to them, whatever may_connect says. GSTVIDEO
		sessions are accepted when gstvideo is supported, even though that type
		is not in the available session types; other unsupported types are skipped.
		"""
		stype = session.session_type
		if stype==SCREEN_TYPE:
			self.sdebug("no setup required for screen sessions", session, may_connect)
			return
		supported = self.settings.get_available_session_types(None, False)
		if stype==X11_TYPE:
			#we cant connect to those (we just setup the ports if any)
			may_connect = False
		elif not (stype==GSTVIDEO_TYPE and self.settings.supports_gstvideo) and stype not in supported:
			self.slog("session type %s is not supported (%s), not setting it up" % (stype, supported), session, may_connect)
			return
		self.do_prepare_session_port(session, may_connect)

	def do_prepare_session_port(self, session, connect=False):
		"""
		Makes the session's port reachable and, when connect is True, attaches to it.
		Closed or timed out sessions are ignored.
		When the server is remote and either it uses an SSH tunnel or the session
		requires one, a port forward is set up (an existing one is re-used) - but
		only when connecting or when the server preloads tunnels - and we attach
		through it on LOCALHOST.
		Otherwise the session is reached directly on its own host and port, an
		empty or wildcard host being replaced with the server's host.
		"""
		if session.status==Session.STATUS_CLOSED or session.timed_out:
			return
		server = self.server
		needs_tunnel = not server.local and (server.ssh_tunnel or session.requires_tunnel)
		if needs_tunnel and (connect or server.preload_tunnels):
			if self.conch_connection==None:
				self.notify("Connection is not tunnelled",
						"Cannot create a new tunnel for session %s without a tunnelled connection!" % session.name,
						notification_type=NOTIFY_ERROR)
				return
			forwarder = self.add_session_port_forward(session, None, session.port)
			#the forward info struct carries the source port in its second slot:
			_, port, _, _ = forwarder
			target_host = LOCALHOST
		else:
			#direct tcp connection
			port = int(session.port)
			target_host = session.host
			if target_host in ("", "0.0.0.0"):
				target_host = server.host
			self.sdebug("not using SSH tunnel, ready on (%s,%s)" % (target_host, port), session, connect)
		if not connect:
			return
		self.may_connect_sound(session)
		#this will start the client process:
		self.attach_to_session(session, target_host, port)


	def attach_to_session(self, session, host, port):
		assert self.session_attach
		try:
			self.session_attach(session, host, port)
		except Exception, e:
			self.serr(None, e, session, host, port)

	def may_connect_sound(self, session):
		self.sdebug("uses_sound_out=%s, uses_sound_in=%s" % (session.uses_sound_out, session.uses_sound_in), session)
		if session.uses_sound_out or session.uses_sound_in:
			callLater(0.5, self.request_session_sound, session)

	def request_session_sound(self, session):
		"""
		Requests the sound streams (input and/or output) used by the session.
		If the session shadows another display, the sound is requested for the
		session found for that display instead.
		Each direction is only requested when the server tunnels it (tunnel_sink
		for input, tunnel_source for output), it is not already live, and the
		session uses it. Status update callbacks are registered on the session
		so that the stream is stopped again later (see the detach states below).
		Any error (including the assertion) is logged via serr, not propagated.
		"""
		self.sdebug(None, session)
		try:
			sound_session = session
			if session.shadowed_display:
				sound_session = self.server.get_session_by_display(session.shadowed_display)
			assert sound_session
			#states in which the stop_sound_* callbacks below are registered (presumably: the stream is stopped when the session reaches one of them)
			detach_states = [Session.STATUS_CLOSED, Session.STATUS_SUSPENDED, Session.STATUS_SUSPENDING, Session.STATUS_AVAILABLE]
			#while connecting, the may_stop_sound_* callbacks decide based on who the actor is
			check_actor_states = [Session.STATUS_CONNECTING]
			#Sound input:
			in_live = sound_session.is_sound_live(True)
			self.sdebug("tunnel_sink(%s)=%s, sound in already live=%s" % (self.server, self.server.tunnel_sink, in_live), session)
			if self.server.tunnel_sink and not in_live and session.uses_sound_in:
				def stop_sound_in():
					self.sdebug()
					#NOTE(review): sound input is stopped with stop_sound(False) but checked with is_sound_live(True) - confirm the meaning of stop_sound's flag
					sound_session.stop_sound(False)
				def may_stop_sound_in():
					self.sdebug("actor=%s, uuid=%s" % (session.actor, self.settings.uuid))
					#we are the one connecting: keep the stream and the callback
					if session.actor==self.settings.uuid:
						return True		#run again
					stop_sound_in()
				session.do_add_status_update_callback(None, detach_states, stop_sound_in, clear_it=True, timeout=None)
				session.do_add_status_update_callback(None, check_actor_states, may_stop_sound_in, clear_it=False, timeout=None)
				self.client.request_sound(sound_session, True, True, False)
			#Sound output:
			out_live = sound_session.is_sound_live(False)
			self.sdebug("tunnel_source(%s)=%s, sound out already live=%s" % (self.server, self.server.tunnel_source, out_live), session)
			if self.server.tunnel_source and not out_live and session.uses_sound_out:
				def stop_sound_out():
					self.sdebug()
					sound_session.stop_sound(True)
				def may_stop_sound_out():
					self.sdebug("actor=%s, uuid=%s" % (session.actor, self.settings.uuid))
					#we are the one connecting: keep the stream and the callback
					if session.actor==self.settings.uuid:
						return True		#run again
					stop_sound_out()
				session.do_add_status_update_callback(None, detach_states, stop_sound_out, clear_it=True, timeout=None)
				session.do_add_status_update_callback(None, check_actor_states, may_stop_sound_out, clear_it=False, timeout=None)
				self.client.request_sound(sound_session, True, False, False)
		except Exception, e:
			self.serr(None, e, session)


	"""
	This method is called when the server has set or changed the ipp or samba ports via set_tunnel_ports
	We then need to (re-)start the tunnels
	"""
	def tunnel_ports_changed(self, old_samba, old_ipp, samba_port, ipp_port):
		"""
		Reacts to the server setting or changing its samba / ipp ports.
		Requests port forwards (with reverse=True) between our local samba /
		printer port and the port announced by the server, each only when that
		port is set and the matching setting (tunnel_fs / tunnel_printer) is on.
		Forwards for the old ports are not released yet.
		Finally sends our mount points when the server tunnels file sharing
		(to a local server only if SEND_MOUNT_POINTS_TO_LOCAL), or when forced.
		"""
		self.sdebug(None, old_samba, old_ipp, samba_port, ipp_port)
		if old_ipp>0 or old_samba>0:
			#FIXME: re-implement this with conch
			self.slog("old ports not freed... TODO!", old_samba, old_ipp, samba_port, ipp_port)

		settings = self.settings
		#file sharing:
		if samba_port>0 and settings.tunnel_fs and settings.local_samba_port>0:
			self.add_port_forward(settings.local_samba_port, LOCALHOST, samba_port, True)
		#printing:
		if ipp_port>0 and settings.tunnel_printer and settings.local_printer_port>0:
			self.add_port_forward(settings.local_printer_port, LOCALHOST, ipp_port, True)

		server = self.server
		if FORCE_SEND_MOUNT_POINTS or (server.tunnel_fs and (not server.local or SEND_MOUNT_POINTS_TO_LOCAL)):
			self.client.send_mount_points()
