#!/usr/bin/env python
"""Functions to run individual GRR components during self-contained testing."""

import atexit
import collections
from collections.abc import Iterable
import os
import signal
import subprocess
import sys
import tempfile
import threading
import time
from typing import Optional, Union

import portpicker

from google.protobuf import text_format
from grr_response_core.lib import package
from grr_response_test.lib import api_helpers
from fleetspeak.src.client.daemonservice.proto.fleetspeak_daemonservice import config_pb2 as daemonservice_config_pb2
from fleetspeak.src.client.generic.proto.fleetspeak_client_generic import config_pb2 as client_config_pb2
from fleetspeak.src.common.proto.fleetspeak import system_pb2
from fleetspeak.src.config.proto.fleetspeak_config import config_pb2
from fleetspeak.src.server.grpcservice.proto.fleetspeak_grpcservice import grpcservice_pb2
from fleetspeak.src.server.proto.fleetspeak_server import server_pb2
from fleetspeak.src.server.proto.fleetspeak_server import services_pb2

# Type alias: a mapping from an option name to an integer or string value.
# NOTE(review): not referenced in the visible part of this module - confirm it
# is still needed by importers before removing.
ComponentOptions = dict[str, Union[int, str]]


class Error(Exception):
  """Base class for all exceptions defined in this module."""


class ConfigInitializationError(Error):
  """Raised when a config-generating subprocess exits with a non-zero code.

  Used by InitGRRConfigs (self_contained_config_writer) and
  InitFleetspeakConfigs (fleetspeak-config).
  """


def _GetServerComponentArgs(config_path: str) -> list[str]:
  """Builds the command line flags shared by GRR server components.

  The packaged grr-server.yaml is used as the primary config; the test config
  shipped with grr-response-test and the given `config_path` are layered on
  top as secondary configs. A monitoring port is picked via portpicker on
  every call.

  Args:
    config_path: Path to a config file generated by
      self_contained_config_writer.

  Returns:
    A list with command line arguments to use.
  """
  base_config = package.ResourcePath(
      "grr-response-core", "install_data/etc/grr-server.yaml")
  test_config = package.ResourcePath(
      "grr-response-test", "grr_response_test/test_data/grr_test.yaml")

  port = portpicker.pick_unused_port()

  # Config overrides, each one passed with its own "-p" flag.
  overrides = (
      f"Monitoring.http_port={port}",
      f"Monitoring.http_port_max={port+10}",
      "AdminUI.webauth_manager=NullWebAuthManager",
  )

  component_args = [
      "--config",
      base_config,
      "--secondary_configs",
      ",".join((test_config, config_path)),
  ]
  for override in overrides:
    component_args += ["-p", override]

  return component_args


def _GetRunEndToEndTestsArgs(
    client_id,
    server_config_path,
    tests: Optional[Iterable[str]] = None,
    manual_tests: Optional[Iterable[str]] = None,
) -> list[str]:
  """Builds the flags that configure the run_end_to_end_tests process.

  The API endpoint is derived from the AdminUI port found in the server
  config; the API credentials are fixed to admin/admin.

  Args:
    client_id: String with a client id pointing to an already running client.
    server_config_path: Path to the server configuration file.
    tests: (Optional) Names of the tests to run; joined with commas.
    manual_tests: (Optional) Names of manual tests to not skip; joined with
      commas.

  Returns:
    A list with command line arguments.
  """
  admin_ui_port = api_helpers.GetAdminUIPortFromConfig(server_config_path)

  flag_values = (
      ("--api_endpoint", "http://localhost:%d" % admin_ui_port),
      ("--api_user", "admin"),
      ("--api_password", "admin"),
      ("--client_id", client_id),
      ("--ignore_test_context", "True"),
  )
  # Flatten the (flag, value) pairs into a single argv-style list.
  result = [item for pair in flag_values for item in pair]

  if tests is not None:
    result.extend(["--run_only_tests", ",".join(tests)])
  if manual_tests is not None:
    result.extend(["--manual_tests", ",".join(manual_tests)])

  return result


def _StartBinary(binary_path: str, args: list[str]) -> subprocess.Popen[bytes]:
  """Launches a binary as an unbuffered child process.

  The child inherits this process's stdout and has its stderr redirected to
  stdout. An atexit hook is registered that kills the child if it is still
  running when the interpreter exits.

  Args:
    binary_path: A binary to run.
    args: A list with program arguments (not containing the program
      executable).

  Returns:
    Popen object corresponding to the started process.
  """
  command = [binary_path] + args
  print("Starting binary: %s" % " ".join(command))
  child = subprocess.Popen(
      command, bufsize=0, stdout=None, stderr=subprocess.STDOUT)

  def _KillIfStillRunning():
    # poll() returns an exit code once the child has finished.
    if child.poll() is not None:
      return
    child.kill()
    child.wait()

  atexit.register(_KillIfStillRunning)

  return child


def _StartComponent(
    main_package: str, args: list[str]
) -> subprocess.Popen[bytes]:
  """Launches a Python module as a child process.

  The current interpreter is re-invoked with "-u" (unbuffered output) and
  "-m <main_package>", so the module is executed as a script. The child
  inherits stdout and has stderr redirected to stdout. An atexit hook kills
  the child if it is still running when this interpreter exits.

  Args:
    main_package: Main package path.
    args: A list with program arguments (not containing the program
      executable).

  Returns:
    Popen object corresponding to the started process.
  """
  command = [sys.executable, "-u", "-m", main_package] + args
  print(f"Starting {main_package} component: {' '.join(command)}")
  child = subprocess.Popen(
      command, bufsize=0, stdout=None, stderr=subprocess.STDOUT)
  print(f"Component {main_package} pid: {child.pid}")

  def _KillIfStillRunning():
    # poll() returns an exit code once the child has finished.
    if child.poll() is not None:
      return
    print(f"Killing {main_package}.")
    child.kill()
    child.wait()

  atexit.register(_KillIfStillRunning)

  return child


# Paths of the GRR config files produced by InitGRRConfigs:
#   server_config: path to the generated server config file.
#   client_config: path to the generated client config file.
GRRConfigs = collections.namedtuple(
    "GRRConfigs",
    ("server_config", "client_config"),
)


def InitGRRConfigs(
    mysql_database: str,
    mysql_username: Optional[str] = None,
    mysql_password: Optional[str] = None,
    logging_path: Optional[str] = None,
    osquery_path: Optional[str] = None,
) -> GRRConfigs:
  """Initializes server and client config files.

  Creates two temporary .yaml files, registers their removal at interpreter
  exit and runs self_contained_config_writer as a subprocess to fill them in.

  Args:
    mysql_database: Name of the MySQL database, passed to the config writer.
    mysql_username: (Optional) MySQL user name; the flag is omitted when None.
    mysql_password: (Optional) MySQL password; the flag is omitted when None.
    logging_path: (Optional) Logging path; the flag is omitted when None.
    osquery_path: (Optional) Path to osquery; the flag is omitted when None.

  Returns:
    GRRConfigs with the paths of the generated server and client configs.

  Raises:
    ConfigInitializationError: If the config writer exits with a non-zero
      code.
  """

  # Create 2 temporary files to contain server and client configuration files
  # that we're about to generate.
  #
  # TODO(user): migrate to TempFilePath as soon grr.test_lib is moved to
  # grr_response_test.
  fd, built_server_config_path = tempfile.mkstemp(".yaml")
  os.close(fd)
  print("Using temp server config path: %s" % built_server_config_path)
  fd, built_client_config_path = tempfile.mkstemp(".yaml")
  os.close(fd)
  print("Using temp client config path: %s" % built_client_config_path)

  def CleanUpConfigs():
    # Each file is removed independently so that a failure on one of them
    # neither skips the other nor raises out of the atexit hook.
    for path in (built_server_config_path, built_client_config_path):
      try:
        os.remove(path)
      except FileNotFoundError:
        pass  # Already gone; nothing left to clean up.
      except OSError as e:
        print("Failed to remove temp config %s: %s" % (path, e))

  atexit.register(CleanUpConfigs)

  # Generate server and client configs.
  config_writer_flags = [
      "--dest_server_config_path",
      built_server_config_path,
      "--dest_client_config_path",
      built_client_config_path,
      "--config_mysql_database",
      mysql_database,
  ]

  # Optional settings are only forwarded when explicitly provided.
  # NOTE: _StartComponent prints the full command line, so a provided MySQL
  # password ends up in the test output.
  optional_flags = (
      ("--config_mysql_username", mysql_username),
      ("--config_mysql_password", mysql_password),
      ("--config_logging_path", logging_path),
      ("--config_osquery_path", osquery_path),
  )
  for flag, value in optional_flags:
    if value is not None:
      config_writer_flags.extend([flag, value])

  p = _StartComponent(
      "grr_response_test.lib.self_contained_config_writer",
      config_writer_flags)
  if p.wait() != 0:
    raise ConfigInitializationError("ConfigWriter execution failed: {}".format(
        p.returncode))

  return GRRConfigs(built_server_config_path, built_client_config_path)


# Paths produced by InitFleetspeakConfigs:
#   server_components_config: Fleetspeak server components config file.
#   server_services_config: Fleetspeak server services config file.
#   client_config: Fleetspeak client config file.
#   logging_path: logging path handed to InitFleetspeakConfigs (may be None);
#     passed as "-log_dir" to the Fleetspeak binaries when set.
FleetspeakConfigs = collections.namedtuple(
    "FleetspeakConfigs",
    (
        "server_components_config",
        "server_services_config",
        "client_config",
        "logging_path",
    ),
)


def InitFleetspeakConfigs(
    grr_configs: GRRConfigs,
    mysql_database: str,
    mysql_username: Optional[str] = None,
    mysql_password: Optional[str] = None,
    logging_path: Optional[str] = None,
) -> FleetspeakConfigs:
  """Generates Fleetspeak server and client configs for self-contained runs.

  All files are written into a freshly created temporary directory, which is
  not removed by this function. The `fleetspeak-config` binary is executed to
  produce the server components config and the client configs; afterwards the
  Linux client config is rewritten so that its configuration directory and
  state file live inside the same temporary directory.

  Args:
    grr_configs: GRR config paths. The server config supplies the Fleetspeak
      frontend and admin ports; the client config path is passed to the GRR
      client started by the Fleetspeak daemon service.
    mysql_database: Name of the MySQL database used in the data source name.
    mysql_username: (Optional) MySQL user; an empty string is used if unset.
    mysql_password: (Optional) MySQL password; an empty string is used if
      unset.
    logging_path: (Optional) Stored unchanged in the returned tuple.

  Returns:
    FleetspeakConfigs with the generated config paths and `logging_path`.

  Raises:
    ConfigInitializationError: If fleetspeak-config exits with a non-zero
      code.
  """

  frontend_port, admin_port = api_helpers.GetFleetspeakPortsFromConfig(
      grr_configs.server_config)

  db_user = mysql_username or ""
  db_password = mysql_password or ""

  workdir = tempfile.mkdtemp(suffix="_fleetspeak")

  def InWorkdir(*parts):
    return os.path.join(workdir, *parts)

  def WriteTextProto(path, message):
    with open(path, mode="w", encoding="utf-8") as out:
      out.write(text_format.MessageToString(message))

  # Input for the fleetspeak-config binary.
  configurator = config_pb2.Config(configuration_name="Self-contained testing")
  components = configurator.components_config
  components.mysql_data_source_name = (
      f"{db_user}:{db_password}@tcp(127.0.0.1:3306)/{mysql_database}")
  https_address = "localhost:%d" % portpicker.pick_unused_port()
  components.https_config.listen_address = https_address
  components.admin_config.listen_address = "localhost:%d" % admin_port
  configurator.public_host_port.append(https_address)

  # Every file that fleetspeak-config reads or writes lives in the workdir.
  generated_files = {
      "server_component_configuration_file": "server.config",
      "trusted_cert_file": "trusted_cert.pem",
      "trusted_cert_key_file": "trusted_cert_key.pem",
      "server_cert_file": "server_cert.pem",
      "server_cert_key_file": "server_cert_key.pem",
      "linux_client_configuration_file": "linux_client.config",
      "windows_client_configuration_file": "windows_client.config",
      "darwin_client_configuration_file": "darwin_client.config",
  }
  for field_name, file_name in generated_files.items():
    setattr(configurator, field_name, InWorkdir(file_name))

  configurator_path = InWorkdir("configurator.config")
  WriteTextProto(configurator_path, configurator)

  configurator_process = _StartBinary(
      "fleetspeak-config",
      ["--logtostderr", "--config", configurator_path])
  exit_code = configurator_process.wait()
  if exit_code != 0:
    raise ConfigInitializationError(
        "fleetspeak-config execution failed: {}".format(exit_code))

  # Adjust client config.
  linux_client_path = configurator.linux_client_configuration_file
  with open(linux_client_path, mode="r", encoding="utf-8") as src:
    client_conf_text = src.read()
  client_conf = text_format.Parse(client_conf_text, client_config_pb2.Config())
  client_conf.filesystem_handler.configuration_directory = workdir
  client_conf.filesystem_handler.state_file = InWorkdir("client.state")
  WriteTextProto(linux_client_path, client_conf)

  # Write client services configuration: a daemon service running the GRR
  # client with the current interpreter.
  daemon_payload = daemonservice_config_pb2.Config(
      argv=[
          sys.executable,
          "-u",
          "-m",
          "grr_response_client.grr_fs_client",
          "--config",
          grr_configs.client_config,
          "--verbose",
      ],
      monitor_heartbeats=True,
      heartbeat_unresponsive_grace_period_seconds=45,
      heartbeat_unresponsive_kill_period_seconds=120,
  )
  client_service = system_pb2.ClientServiceConfig(name="GRR", factory="Daemon")
  client_service.config.Pack(daemon_payload)

  os.mkdir(InWorkdir("textservices"))
  WriteTextProto(InWorkdir("textservices", "GRR.textproto"), client_service)

  # Server services configuration: a GRPC service targeting the Fleetspeak
  # frontend port read from the GRR server config.
  grpc_service = services_pb2.ServiceConfig(name="GRR", factory="GRPC")
  grpc_service.config.Pack(
      grpcservice_pb2.Config(
          target="localhost:%d" % frontend_port, insecure=True))
  server_services = server_pb2.ServerConfig(services=[grpc_service])
  server_services.broadcast_poll_time.seconds = 1

  server_services_path = InWorkdir("server.services.config")
  WriteTextProto(server_services_path, server_services)

  return FleetspeakConfigs(
      server_components_config=(
          configurator.server_component_configuration_file),
      server_services_config=server_services_path,
      client_config=linux_client_path,
      logging_path=logging_path,
  )


def StartServerProcesses(
    grr_configs: GRRConfigs,
    fleetspeak_configs: FleetspeakConfigs,
) -> list[subprocess.Popen]:
  """Starts the Fleetspeak server and the GRR server components.

  Processes are started in this order: fleetspeak-server, GRR Fleetspeak
  frontend, admin UI, worker. Each GRR component gets its own freshly built
  argument list.

  Args:
    grr_configs: GRR config paths; the server config is passed to every GRR
      component.
    fleetspeak_configs: Fleetspeak config paths used to start
      fleetspeak-server.

  Returns:
    A list with Popen objects of the started processes, in start order.
  """
  fs_server_args = [
      "-v",
      "2",
      "-components_config",
      fleetspeak_configs.server_components_config,
      "-services_config",
      fleetspeak_configs.server_services_config,
  ]
  log_dir = fleetspeak_configs.logging_path
  if log_dir is not None:
    fs_server_args += ["-log_dir", log_dir]

  grr_components = (
      "grr_response_server.bin.fleetspeak_frontend",
      "grr_response_server.gui.admin_ui",
      "grr_response_server.bin.worker",
  )

  started = [_StartBinary("fleetspeak-server", fs_server_args)]
  for component in grr_components:
    # Arguments are rebuilt for every component rather than shared.
    component_args = _GetServerComponentArgs(grr_configs.server_config)
    started.append(_StartComponent(component, component_args))

  return started


def StartClientProcess(
    fleetspeak_configs: FleetspeakConfigs,
) -> subprocess.Popen[bytes]:
  """Starts the fleetspeak-client binary with the generated client config.

  Args:
    fleetspeak_configs: Fleetspeak config paths; `client_config` is passed via
      "-config" and `logging_path`, when set, via "-log_dir".

  Returns:
    Popen object corresponding to the started fleetspeak-client process.
  """
  client_args = ["-v", "2", "-std_forward"]
  client_args += ["-config", fleetspeak_configs.client_config]

  log_dir = fleetspeak_configs.logging_path
  if log_dir is not None:
    client_args += ["-log_dir", log_dir]

  return _StartBinary("fleetspeak-client", client_args)


def RunEndToEndTests(client_id: str,
                     server_config_path: str,
                     tests: Optional[Iterable[str]] = None,
                     manual_tests: Optional[Iterable[str]] = None):
  """Runs end to end tests on a given client and waits for them to finish.

  Args:
    client_id: Id of an already running client to test against.
    server_config_path: Path to the server configuration file.
    tests: (Optional) Names of the tests to run.
    manual_tests: (Optional) Names of manual tests to not skip.

  Raises:
    RuntimeError: If the test runner process exits with a non-zero code.
  """
  component_args = _GetServerComponentArgs(server_config_path)
  test_args = _GetRunEndToEndTestsArgs(
      client_id, server_config_path, tests=tests, manual_tests=manual_tests)

  runner = _StartComponent(
      "grr_response_test.run_end_to_end_tests", component_args + test_args)
  if runner.wait() == 0:
    return
  raise RuntimeError("RunEndToEndTests execution failed.")


# Seconds to sleep between two consecutive polling rounds.
_PROCESS_CHECK_INTERVAL = 0.1


def _DieIfSubProcessDies(processes: Iterable[subprocess.Popen],
                         already_dead_event: threading.Event):
  """Polls the processes and terminates everything once one of them fails.

  Runs until a monitored process is found to have exited with a non-zero
  code. Processes that exit with code 0 are ignored. On failure, SIGTERM is
  sent to every monitored process and then to the current process, after
  which the function returns.

  Args:
    processes: Processes to monitor. The iterable is consumed once up front.
    already_dead_event: Event that is set when the main process is already
      exiting; if set when a failure is detected, nothing is killed.
  """
  # Snapshot the iterable: it is walked on every polling round, so a one-shot
  # iterable (e.g. a generator) would otherwise only be checked once.
  processes = list(processes)

  while True:
    for p in processes:
      if p.poll() in (None, 0):
        continue

      # Prevent a double kill. When the main process exits, it kills the
      # children. We don't want a child's death to cause a SIGTERM being
      # sent to a process that's already exiting.
      if already_dead_event.is_set():
        return

      # This function runs in a background thread, raising an exception
      # will just kill the thread while what we want is to fail the whole
      # process.
      print("Subprocess %s died unexpectedly. Killing main process..." %
            p.pid)
      for kp in processes:
        try:
          os.kill(kp.pid, signal.SIGTERM)
        except OSError:
          # Best effort: the process may already be gone.
          pass
      # sys.exit only exits a thread when called from a thread, so the main
      # process is signalled with SIGTERM instead.
      os.kill(os.getpid(), signal.SIGTERM)
      # Everything has been signalled. Stop monitoring, otherwise the children
      # terminated above would be reported as failures and trigger the same
      # signals again on every following iteration.
      return
    time.sleep(_PROCESS_CHECK_INTERVAL)


def DieIfSubProcessDies(
    processes: Iterable[subprocess.Popen]) -> threading.Thread:
  """Starts a daemon thread that watches the given processes.

  The thread runs _DieIfSubProcessDies, which terminates the monitored
  processes and the current process once one of them exits with a non-zero
  code, so that failures don't go unnoticed. An atexit hook marks the current
  process as already exiting so that the monitor does not signal it again.

  Args:
    processes: An iterable with subprocess.Popen instances to monitor.

  Returns:
    Background thread started to monitor the processes.
  """
  exiting = threading.Event()

  monitor = threading.Thread(
      target=_DieIfSubProcessDies,
      args=(processes, exiting),
      daemon=True,
  )
  monitor.start()

  # Set at interpreter exit to prevent the monitor from killing a process
  # that is already shutting down.
  atexit.register(exiting.set)

  return monitor
