import json
import os
import shlex
import subprocess
import sys
import tempfile
import threading
import time
from collections import deque
from typing import Optional, Dict, Any

from ray_release.anyscale_util import LAST_LOGS_LENGTH

from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.exception import (
    ResultsError,
    LocalEnvSetupError,
    ClusterNodesWaitTimeout,
    CommandTimeout,
    ClusterStartupError,
    CommandError,
)
from ray_release.file_manager.file_manager import FileManager
from ray_release.logger import logger
from ray_release.command_runner.command_runner import CommandRunner
from ray_release.wheels import install_matching_ray_locally


def install_cluster_env_packages(cluster_env: Dict[Any, Any]):
    os.environ.update(cluster_env.get("env_vars", {}))
    packages = cluster_env["python"]["pip_packages"]
    logger.info(f"Installing cluster env packages locally: {packages}")

    for package in packages:
        subprocess.check_output(
            f"pip install -U {shlex.quote(package)}",
            shell=True,
            env=os.environ,
            text=True,
        )


class ClientRunner(CommandRunner):
    def __init__(
        self,
        cluster_manager: ClusterManager,
        file_manager: FileManager,
        working_dir: str,
    ):
        super(ClientRunner, self).__init__(cluster_manager, file_manager, working_dir)

        self.last_logs = None
        self.result_output_json = tempfile.mktemp()

    def prepare_remote_env(self):
        pass

    def prepare_local_env(self, ray_wheels_url: Optional[str] = None):
        try:
            install_matching_ray_locally(
                ray_wheels_url or os.environ.get("RAY_WHEELS", None)
            )
            install_cluster_env_packages(self.cluster_manager.cluster_env)
        except Exception as e:
            raise LocalEnvSetupError(f"Error setting up local environment: {e}") from e

    def wait_for_nodes(self, num_nodes: int, timeout: float = 900):
        import ray

        ray_address = self.cluster_manager.get_cluster_address()
        try:
            if ray.is_initialized:
                ray.shutdown()

            ray.init(address=ray_address)

            start_time = time.monotonic()
            timeout_at = start_time + timeout
            next_status = start_time + 30
            nodes_up = len(ray.nodes())
            while nodes_up < num_nodes:
                now = time.monotonic()
                if now >= timeout_at:
                    raise ClusterNodesWaitTimeout(
                        f"Only {len(ray.nodes())}/{num_nodes} are up after "
                        f"{timeout} seconds."
                    )

                if now >= next_status:
                    logger.info(
                        f"Waiting for nodes to come up: "
                        f"{len(ray.nodes())}/{num_nodes} "
                        f"({now - start_time:.2f} seconds, "
                        f"timeout: {timeout} seconds)."
                    )
                    next_status += 30

                time.sleep(1)
                nodes_up = len(ray.nodes())

            ray.shutdown()
        except Exception as e:
            raise ClusterStartupError(f"Exception when waiting for nodes: {e}") from e

        logger.info(f"All {num_nodes} nodes are up.")

    def get_last_logs(self) -> Optional[str]:
        return self.last_logs

    def fetch_results(self) -> Dict[str, Any]:
        try:
            with open(self.result_output_json, "rt") as fp:
                return json.load(fp)
        except Exception as e:
            raise ResultsError(
                f"Could not load local results from " f"client command: {e}"
            ) from e

    def run_command(
        self, command: str, env: Optional[Dict] = None, timeout: float = 3600.0
    ) -> float:
        logger.info(
            f"Running command using Ray client on cluster "
            f"{self.cluster_manager.cluster_name}: {command}"
        )

        env = env or {}
        full_env = self.get_full_command_env(
            {
                **os.environ,
                **env,
                "RAY_ADDRESS": self.cluster_manager.get_cluster_address(),
                "RAY_JOB_NAME": "test_job",
                "PYTHONUNBUFFERED": "1",
            }
        )

        kill_event = threading.Event()

        def _kill_after(
            proc: subprocess.Popen,
            timeout: int = 30,
            kill_event: Optional[threading.Event] = None,
        ):
            timeout_at = time.monotonic() + timeout
            while time.monotonic() < timeout_at:
                if proc.poll() is not None:
                    return
                time.sleep(1)
            logger.info(
                f"Client command timed out after {timeout} seconds, "
                f"killing subprocess."
            )
            if kill_event:
                kill_event.set()
            proc.terminate()

        start_time = time.monotonic()
        proc = subprocess.Popen(
            command,
            env=full_env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            shell=True,
            text=True,
        )

        kill_thread = threading.Thread(
            target=_kill_after, args=(proc, timeout, kill_event)
        )
        kill_thread.start()

        proc.stdout.reconfigure(line_buffering=True)
        sys.stdout.reconfigure(line_buffering=True)
        logs = deque(maxlen=LAST_LOGS_LENGTH)
        for line in proc.stdout:
            logs.append(line)
            sys.stdout.write(line)
        proc.wait()
        sys.stdout.reconfigure(line_buffering=False)
        time_taken = time.monotonic() - start_time
        self.last_logs = "\n".join(logs)

        return_code = proc.poll()
        if return_code == -15 or return_code == 15 or kill_event.is_set():
            # Process has been terminated
            raise CommandTimeout(f"Cluster command timed out after {timeout} seconds.")
        if return_code != 0:
            raise CommandError(f"Command returned non-success status: {return_code}")

        logger.warning(f"WE GOT RETURN CODE {return_code} AFTER {time_taken}")

        return time_taken