from ableton.v2.control_surface import ControlSurface

from . import live

import importlib
import traceback
import logging
import os

# Named "grip" logger shared by this module; Manager.start_logging attaches
# the log-file and OSC-forwarding handlers to it.
logger = logging.getLogger(name="grip")

class Manager(ControlSurface):
    """
    Grip's Ableton Live control surface.

    Creates the OSC server, polls it from Live's scheduler (see tick()),
    registers the API handlers, and configures logging to logs/grip.log plus
    OSC forwarding of log records.
    """

    def __init__(self, c_instance):
        ControlSurface.__init__(self, c_instance)

        self.log_level = "info"

        self.handlers = []
        # Stays None if the OSC server cannot be created, so disconnect()
        # knows there is nothing to shut down.
        self.osc_server = None

        try:
            self.osc_server = live.OSCServer()
            self.schedule_message(0, self.tick)

            self.start_logging()
            self.init_api()

            # Wrap the ports value in a 1-tuple: if GRIP_OSC_LISTEN_PORTS is a
            # tuple, "%s" % ports would raise TypeError instead of formatting it.
            self.show_message("Grip: Listening on ports %s" % (live.GRIP_OSC_LISTEN_PORTS,))
            logger.info("Started Grip on ports %s, responding to %s",
                        live.GRIP_OSC_LISTEN_PORTS, live.GRIP_OSC_RESPONSE_PORTS)
        except OSError as msg:
            # Report the same GRIP_OSC_LISTEN_PORTS constant used on the success
            # path; it holds several ports, so format it with %s rather than %d.
            self.show_message("Grip: Couldn't bind to ports %s (%s)" % (live.GRIP_OSC_LISTEN_PORTS, msg))
            # Logged at ERROR so the failure is not dropped by INFO-level filtering.
            logger.error("Couldn't bind to ports %s (%s)", live.GRIP_OSC_LISTEN_PORTS, msg)


    def start_logging(self):
        """
        Start logging to a local logfile (logs/grip.log), relay ERROR-and-above
        messages via OSC (/live/error), and forward every record that reaches
        the handler via OSC (/live/log).
        """
        module_path = os.path.dirname(os.path.realpath(__file__))
        log_dir = os.path.join(module_path, "logs")
        # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
        os.makedirs(log_dir, 0o755, exist_ok=True)
        log_path = os.path.join(log_dir, "grip.log")
        logger.info("Grip log file: %s", log_path)
        self.show_message("Grip log: %s" % log_path)
        self.log_file_handler = logging.FileHandler(log_path)
        self.log_file_handler.setLevel(self.log_level.upper())
        formatter = logging.Formatter('(%(asctime)s) [%(levelname)s] %(message)s')
        self.log_file_handler.setFormatter(formatter)
        logger.addHandler(self.log_file_handler)

        manager = self

        class LiveOSCErrorLogHandler(logging.Handler):
            """Send error records to the client as /live/error."""

            def emit(self, record):
                try:
                    message = record.getMessage()
                    # Drop everything up to the first ":" plus the character after
                    # it (presumably a "Prefix: " on the message). Messages with
                    # no colon are sent whole instead of raising ValueError.
                    _, colon, rest = message.partition(":")
                    if colon:
                        message = rest[1:]
                    manager.osc_server.send("/live/error", (message,))
                except OSError:
                    # If the connection is dead, silently ignore errors as there's not much more we can do
                    pass
                except Exception:
                    # Never let a forwarding failure propagate into the logging caller.
                    self.handleError(record)
        self.live_osc_error_handler = LiveOSCErrorLogHandler()
        self.live_osc_error_handler.setLevel(logging.ERROR)
        logger.addHandler(self.live_osc_error_handler)

        #--------------------------------------------------------------------------------
        # Grip: Forward all logs to companion via OSC
        # Broadcasts to all response ports (prod: 47141, dev: 47241)
        #--------------------------------------------------------------------------------
        class GripLogHandler(logging.Handler):
            """Forward each record as /live/log (level_name, message)."""

            def __init__(self, osc_server):
                super().__init__()
                self.osc_server = osc_server

            def emit(self, record):
                try:
                    level_name = record.levelname.lower()
                    message = record.getMessage()
                    # Broadcast to all daemon receive ports so both prod and dev daemons get logs
                    self.osc_server.send("/live/log", (level_name, message))
                except OSError:
                    # If the connection is dead, silently ignore errors
                    pass
                except Exception:
                    # Don't let log forwarding break Grip; report via the standard
                    # logging error hook (stderr) instead of discarding silently.
                    self.handleError(record)

        self.grip_log_handler = GripLogHandler(self.osc_server)
        self.grip_log_handler.setLevel(logging.DEBUG)  # Forward all log levels
        logger.addHandler(self.grip_log_handler)

    def stop_logging(self):
        """
        Detach the handlers installed by start_logging() and close them.

        Safe to call when start_logging() did not run or only partly ran
        (e.g. the OSC server failed to start).
        """
        for attr in ("log_file_handler", "live_osc_error_handler", "grip_log_handler"):
            handler = getattr(self, attr, None)
            if handler is None:
                continue
            logger.removeHandler(handler)
            # removeHandler() does not close the handler; without this the
            # FileHandler keeps logs/grip.log open after disconnect.
            handler.close()

    def init_api(self):
        """Register the manager-level OSC routes and build the API handlers."""
        def test_callback(params):
            self.show_message("Received OSC OK")
            self.osc_server.send("/live/test", ("ok",))
        def reload_callback(params):
            self.reload_imports()
        def get_log_level_callback(params):
            return (self.log_level,)
        def set_log_level_callback(params):
            log_level = params[0]
            # Explicit check rather than assert, which is stripped under `python -O`.
            if log_level not in ("debug", "info", "warning", "error", "critical"):
                raise ValueError("Invalid log level: %r" % (log_level,))
            self.log_level = log_level
            self.log_file_handler.setLevel(self.log_level.upper())

        def clear_all_listeners_callback(params):
            """Clear all listeners across all handlers for clean reconnection."""
            for handler in self.handlers:
                handler._clear_listeners()
            logger.info("Cleared all listeners for reconnection")

        self.osc_server.add_handler("/live/test", test_callback)
        self.osc_server.add_handler("/live/api/reload", reload_callback)
        self.osc_server.add_handler("/live/api/clear_all_listeners", clear_all_listeners_callback)
        self.osc_server.add_handler("/live/api/get/log_level", get_log_level_callback)
        self.osc_server.add_handler("/live/api/set/log_level", set_log_level_callback)

        with self.component_guard():
            self.handlers = [
                live.SongHandler(self),
                live.ApplicationHandler(self),
                live.ClipHandler(self),
                live.ClipSlotHandler(self),
                live.TrackHandler(self),
                live.DeviceHandler(self),
                live.ViewHandler(self),
                live.SceneHandler(self),
                live.PingHandler(self),
            ]

    def clear_api(self):
        """Remove all OSC routes and let each API handler tear down its own."""
        self.osc_server.clear_handlers()
        for handler in self.handlers:
            handler.clear_api()

    def tick(self):
        """
        Called once per 100ms "tick".
        Live's embedded Python implementation does not appear to support threading,
        and beachballs when a thread is started. Instead, this approach allows long-running
        processes such as the OSC server to perform operations.
        """
        logger.debug("Tick...")
        try:
            self.osc_server.process()
        finally:
            # Reschedule even if process() raised, so one exception does not
            # permanently stop the OSC polling loop.
            self.schedule_message(1, self.tick)

    def reload_imports(self):
        """Hot-reload the live.* modules, then rebuild the OSC API."""
        try:
            # Reload the base handler module first: the handler modules below
            # presumably import from it, and reloading them afterwards makes them
            # bind to the fresh definitions.
            importlib.reload(live.handler)
            importlib.reload(live.application)
            importlib.reload(live.clip)
            importlib.reload(live.clip_slot)
            importlib.reload(live.device)
            importlib.reload(live.osc_server)
            importlib.reload(live.scene)
            importlib.reload(live.song)
            importlib.reload(live.track)
            importlib.reload(live.view)
            importlib.reload(live)
        except Exception:
            # Use the grip logger (not the root logger) so the traceback reaches
            # grip.log and no implicit logging.basicConfig() is triggered.
            logger.warning(traceback.format_exc())

        self.clear_api()
        self.init_api()
        logger.info("Reloaded code")

    def disconnect(self):
        """Tear down logging and the OSC server, then disconnect the surface."""
        self.show_message("Disconnecting...")
        logger.info("Disconnecting...")
        self.stop_logging()
        if self.osc_server is not None:
            self.osc_server.shutdown()
        super().disconnect()

