#! /usr/bin/python3

import socket
import time
import argparse
import sys
import threading
import re

import logging

# Module logger: records are written through a stream handler using a
# "[time][logger name][level] message" layout, at INFO level by default.
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] %(message)s')
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

# Shutdown flag shared by the input thread and the main send loop;
# either side sets it to True to ask the other one to stop.
quit = False
# Serialises every access to the PeriodController instance, which is
# touched both by the main loop and by the input thread.
lock = threading.Lock()

class SyncSocket:
    """
    TCP client socket used to push 'sync' messages to a remote endpoint.

    timeout is the maximum time (in s) connect() keeps retrying,
    (host, port) is the destination address.
    """
    def __init__(self, timeout, host, port):
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.dest = (host, port)
        self.timeout = timeout
        # reference time used by send() to log the delay between two messages
        self.last_call = time.time()

    def connect(self):
        """
        Try to connect to the destination until it succeeds or the
        timeout expires, waiting 0.5s between two refused attempts.

        Return True if the connection has been established, False
        otherwise. Only ConnectionRefusedError is retried, any other
        error is propagated to the caller.
        """
        begin = time.time()
        while time.time() - begin < self.timeout:
            try:
                self.sock.connect(self.dest)
                return True
            except ConnectionRefusedError:
                # The state of a socket is unspecified after a failed
                # connect(): start again from a fresh one instead of
                # retrying on the same descriptor.
                self.sock.close()
                self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                time.sleep(0.5)
        return False

    def send(self):
        """
        Send one 'sync' message and log (debug level) the time elapsed
        since the previous call.
        """
        # sendall, unlike send, guarantees the whole message is written
        self.sock.sendall(b'sync')
        now = time.time()
        logger.debug("Send sync message after %f", now - self.last_call)
        self.last_call = now


class InputThread(threading.Thread):
    """
    Thread reading commands, one per line, on the standard input.

    Recognised commands:
      - "quit": set the global quit flag
      - "set_period <number>": change the period of the global
        PeriodController (period_ctrl)
    Any other line is ignored. End of input is handled like "quit".
    """
    def __init__(self):
        super().__init__()
        # raw string: "\d" is an invalid escape sequence in a plain literal
        self._set_period_pattern = re.compile(r"^set_period (\d+\.?\d*)")

    def run(self):
        """ Read and process lines until the global quit flag is set. """
        global quit
        while not quit:
            try:
                text_in = input()
            except EOFError:
                # stdin has been closed, nothing more to read
                quit = True
            else:
                self._process(text_in)

    def _process(self, text_in):
        """ Interpret one line of input, see the class documentation. """
        if text_in == "quit":
            # unlocked access, but should be fine
            global quit
            quit = True
            return

        set_period_match = self._set_period_pattern.match(text_in)
        if set_period_match:
            new_period = float(set_period_match.group(1))
            # defensive guard: the pattern above only accepts unsigned
            # numbers, so this is not expected to trigger
            if new_period < 0.0:
                logger.error("Invalid period value %f", new_period)
                return

            # lock and period_ctrl are module globals; they are only read
            # here, so no 'global' declaration is needed
            with lock:
                period_ctrl.update_period(new_period)

class PeriodController:
    """
    A simple P controller to compute the right "time" to sleep to
    enforce the wished period.

    The real period is measured as the time between two calls to
    compute_period(). Every _WINDOW measures, the sleep time is
    corrected by _GAIN times the error between the wished period and
    the mean of the measures.
    """
    # number of measures averaged before each correction
    _WINDOW = 5
    # proportional gain applied to the error
    _GAIN = 0.5

    def __init__(self, wished_period):
        self.update_period(wished_period)

    def _mean(self):
        """ Return the mean of the current measures (must not be empty). """
        return sum(self._samples) / len(self._samples)

    def update_period(self, period):
        """ Set a new wished period (in s) and reset the controller state. """
        self._wished_period = period
        self._computed_period = period
        self._samples = []
        self._last_call = None

    def compute_period(self):
        """
        Record the time elapsed since the previous call and return the
        time to sleep (in s, never negative once corrected).

        The first call after a (re)configuration only records the
        reference time and returns the wished period.
        """
        # one single time reading, so no time is lost between the
        # measure and the new reference
        now = time.time()
        if self._last_call is None:
            self._last_call = now
            return self._wished_period

        self._samples.append(now - self._last_call)
        self._last_call = now

        if len(self._samples) >= self._WINDOW:
            error = self._wished_period - self._mean()
            # Clamp to 0: when the loop is slower than the wished period
            # the correction would otherwise become negative, which
            # time.sleep() rejects.
            self._computed_period = max(
                0.0, self._computed_period + self._GAIN * error)
            # Always start a new window, even if the error is null,
            # otherwise the measures pile up and no correction is ever
            # computed again.
            self._samples.clear()

        return self._computed_period


if __name__ == '__main__':
    # Script entry point: parse the command line, connect to the sync
    # socket, then send one sync message per period until asked to quit
    # (from the standard input) or until the connection is lost.
    parser = argparse.ArgumentParser(description='MORSE syncer')

    # The period has no default value: make it mandatory instead of
    # failing later when trying to sleep for 'None' seconds.
    parser.add_argument('-p', '--period', required=True, help="period of synchronisation (in s)", type=float)
    parser.add_argument('-t', '--timeout', default=10.0, help="connect timeout (in s)", type=float)
    parser.add_argument('-H', '--host', default="localhost", help="MORSE socket sync addr")
    parser.add_argument('-P', '--port', default=6000, help="MORSE socket sync port", type=int)
    parser.add_argument('-d', '--debug', action='store_true', help="Enable debug information")

    args = parser.parse_args()

    # time.sleep() rejects negative values
    if args.period < 0.0:
        parser.error("period must be non-negative")

    if args.debug:
        logger.setLevel(logging.DEBUG)

    s = SyncSocket(args.timeout, args.host, args.port)

    # period_ctrl must exist before the input thread starts, as the
    # thread uses it as soon as it receives a set_period command
    period_ctrl = PeriodController(args.period)

    input_thr = InputThread()
    input_thr.start()

    if not s.connect():
        logger.error("Failed to connect to (%s, %s)", args.host, args.port)
        quit = True
        input_thr.join()
        # report the failure to the caller
        sys.exit(1)

    while not quit:
        try:
            s.send()
        except (BrokenPipeError, ConnectionResetError):
            # connection lost: stop now, no need to wait one more period
            quit = True
            break
        # The lock only protects the controller; sleeping outside of it
        # lets the input thread update the period without waiting for
        # the end of the current period.
        with lock:
            sleep_time = period_ctrl.compute_period()
        time.sleep(sleep_time)

    input_thr.join()
