#!/usr/bin/env python
"""Fleetspeak-facing client related functionality.

This module contains glue code necessary for Fleetspeak and the GRR client
to work together.
"""

import logging
import pdb
import queue
import struct
import threading
import time
import zlib

from absl import flags

from grr_response_client import comms
from grr_response_core import config
from grr_response_core.lib import rdfvalue
from grr_response_core.lib.rdfvalues import flows as rdf_flows
from grr_response_core.lib.rdfvalues import protodict as rdf_protodict
from grr_response_proto import jobs_pb2
from fleetspeak.src.common.proto.fleetspeak import common_pb2 as fs_common_pb2
from fleetspeak.client_connector import connector as fs_client


# fmt: off

# Logged by GRRFleetspeakClient.Run once all client threads have been started.
# NOTE(review): presumably kept in sync with the start string used in
# grr_response_client/comms.py — confirm.
START_STRING = "Starting client."

# fmt: on

# Soft limit on the total serialized size of GrrMessages batched into a single
# PackedMessageList (before sending to Fleetspeak).
_MAX_MSG_LIST_BYTES = 1 << 20  # 1 MiB

# Maximum number of GrrMessages to put in one PackedMessageList.
_MAX_MSG_LIST_MSG_COUNT = 100

# Size threshold at which no further annotations are added to a Fleetspeak
# message.
_MAX_ANNOTATIONS_BYTES = 3 << 10  # 3 KiB

# Annotation key whose value is "<session basename>:<request_id>:<response_id>".
_DATA_IDS_ANNOTATION_KEY = "data_ids"


class FatalError(Exception):
  """Raised when the client cannot keep running (e.g. a client thread died)."""


class BrokenFSConnectionError(Exception):
  """Raised when reading from the local Fleetspeak connection fails."""


def _EncodeMessageList(
    message_list: rdf_flows.MessageList,
    packed_message_list: rdf_flows.PackedMessageList,
) -> None:
  """Serializes `message_list` into `packed_message_list`.

  The zlib-compressed serialization is stored, with `compression` set to
  ZCOMPRESSION, only when it is strictly smaller than the raw serialization.
  Otherwise the raw bytes are stored and `compression` is left untouched.

  Args:
    message_list: The messages to encode.
    packed_message_list: Output rdfvalue, modified in place.
  """
  raw_bytes = message_list.SerializeToBytes()
  packed_message_list.message_list = raw_bytes

  deflated = zlib.compress(raw_bytes)
  if len(deflated) >= len(raw_bytes):
    # Compression does not save any space; keep the raw payload.
    return

  packed_message_list.compression = (
      rdf_flows.PackedMessageList.CompressionType.ZCOMPRESSION
  )
  packed_message_list.message_list = deflated


class GRRFleetspeakClient(object):
  """A Fleetspeak enabled client implementation.

  Runs the GRR client worker plus three daemon threads:
  - a Foreman poller;
  - a sender that batches outgoing GrrMessages into Fleetspeak messages;
  - a receiver that hands incoming Fleetspeak messages to the worker.

  `Run` supervises all of them and raises `FatalError` as soon as one dies.
  """

  # Only buffer at most ~100MB of data - the estimate comes from the Fleetspeak
  # message size limit - Fleetspeak refuses to process messages larger than 2MB
  # (50 queue entries x 2MB). This is a sanity safeguard against unlimited
  # memory consumption.
  _SENDER_QUEUE_MAXSIZE = 50

  def __init__(self):
    self._fs = fs_client.FleetspeakConnection(
        version=config.CONFIG["Source.version_string"]
    )

    # Outgoing GrrMessages, filled by the worker (via the forwarder below) and
    # drained by the sender thread.
    self._sender_queue = queue.Queue(
        maxsize=GRRFleetspeakClient._SENDER_QUEUE_MAXSIZE
    )

    # Thread name -> thread-like object; all are started and monitored by Run.
    self._threads = {}

    # The client worker does all the real work here.
    # In particular, we delegate sending messages to Fleetspeak to a separate
    # threading.Thread here.
    out_queue = _FleetspeakQueueForwarder(self._sender_queue)
    worker = self._threads["Worker"] = comms.GRRClientWorker(
        out_queue=out_queue, heart_beat_cb=self._fs.Heartbeat, client=self
    )
    # TODO(user): this is an ugly way of passing the heartbeat callback to
    # the queue. Refactor the heartbeat callback initialization logic so that
    # this won't be needed.
    out_queue.heart_beat_cb = worker.Heartbeat

    self._threads["Foreman"] = self._CreateThread(self._ForemanOp)
    self._threads["Sender"] = self._CreateThread(self._SendOp)
    self._threads["Receiver"] = self._CreateThread(self._ReceiveOp)

  def _CreateThread(self, loop_op):
    """Returns an unstarted daemon thread that runs `loop_op` in a loop.

    Args:
      loop_op: Zero-argument callable performing one iteration of work.

    Returns:
      threading.Thread
    """
    thread = threading.Thread(target=self._RunInLoop, args=(loop_op,))
    thread.daemon = True
    return thread

  def _RunInLoop(self, loop_op):
    """Calls `loop_op` repeatedly until it raises.

    Any exception terminates the calling thread. `Run` then notices the dead
    thread and raises FatalError.

    Args:
      loop_op: Zero-argument callable performing one iteration of work.

    Raises:
      BrokenFSConnectionError: Re-raised without additional logging.
      Exception: Any other error, after logging it (and entering the debugger
        when the pdb_post_mortem flag is set).
    """
    while True:
      try:
        loop_op()
      except BrokenFSConnectionError:
        # This happens during Fleetspeak shutdown and was already logged in the
        # receiver thread so we skip the additional stack trace here.
        raise
      except Exception:
        logging.critical("Fatal error occurred:", exc_info=True)
        if flags.FLAGS.pdb_post_mortem:
          pdb.post_mortem()
        # This will terminate execution in the current thread.
        raise

  def Run(self):
    """The main run method of the client.

    Starts all threads, logs START_STRING, then checks every 10 seconds that
    every thread is still alive. Never returns normally.

    Raises:
      FatalError: As soon as any thread is found dead.
    """
    for thread in self._threads.values():
      thread.start()
    logging.info(START_STRING)

    while True:
      dead_threads = [
          tn for (tn, t) in self._threads.items() if not t.is_alive()
      ]
      if dead_threads:
        raise FatalError(
            "These threads are dead: %r. Shutting down..." % dead_threads
        )
      time.sleep(10)

  def _ForemanOp(self):
    """Sends one Foreman check, then sleeps.

    The check is an empty DataBlob sent to the "Foreman" flow session. The
    sleep lasts Client.foreman_check_frequency seconds.
    """
    period = config.CONFIG["Client.foreman_check_frequency"]
    self._threads["Worker"].SendReply(
        rdf_protodict.DataBlob(),
        session_id=rdfvalue.FlowSessionID(flow_name="Foreman"),
    )
    time.sleep(period)

  def _SendMessages(self, grr_msgs, background=False):
    """Sends a block of messages through Fleetspeak.

    The messages are packed into a single (possibly zlib-compressed)
    PackedMessageList addressed to the "GRR" Fleetspeak service.

    Each message that has session, request and response ids gets a compact
    "<session basename>:<request_id>:<response_id>" annotation. Annotating
    stops once the annotations reach _MAX_ANNOTATIONS_BYTES. The entry that
    crosses the threshold is kept, so the threshold can be exceeded by one
    entry.

    Args:
      grr_msgs: List of GrrMessages to send.
      background: Value for the Fleetspeak message's `background` field.

    Raises:
      OSError, struct.error: If writing to the local Fleetspeak connection
        fails. The error is logged before being re-raised.
    """
    message_list = rdf_flows.PackedMessageList()
    _EncodeMessageList(rdf_flows.MessageList(job=grr_msgs), message_list)
    fs_msg = fs_common_pb2.Message(
        message_type="MessageList",
        destination=fs_common_pb2.Address(service_name="GRR"),
        background=background,
    )
    fs_msg.data.Pack(message_list.AsPrimitiveProto())

    for grr_msg in grr_msgs:
      if (
          grr_msg.session_id is None
          or grr_msg.request_id is None
          or grr_msg.response_id is None
      ):
        continue
      # Place all ids in a single annotation, instead of having separate
      # annotations for the flow-id, request-id and response-id. This reduces
      # overall size of the annotations by half (~60 bytes to ~30 bytes).
      annotation = fs_msg.annotations.entries.add()
      annotation.key = _DATA_IDS_ANNOTATION_KEY
      annotation.value = "%s:%d:%d" % (
          grr_msg.session_id.Basename(),
          grr_msg.request_id,
          grr_msg.response_id,
      )
      if fs_msg.annotations.ByteSize() >= _MAX_ANNOTATIONS_BYTES:
        break

    try:
      self._fs.Send(fs_msg)
    except (OSError, struct.error) as e:
      logging.critical("Broken local Fleetspeak connection (write end): %s", e)
      raise

  def _SendOp(self):
    """Sends one batch of queued GrrMessages through Fleetspeak.

    Blocks until at least one message is available. It then keeps collecting
    messages, waiting up to 1 second for each, until one of these happens:
    - the queue stays empty;
    - the batch holds _MAX_MSG_LIST_MSG_COUNT messages;
    - the batch reaches _MAX_MSG_LIST_BYTES serialized bytes.

    Both limits are checked before the next message is dequeued, so the byte
    limit can be exceeded by one message.
    """
    first_msg = self._sender_queue.get()
    msgs = [first_msg]
    size = len(first_msg.SerializeToBytes())

    while len(msgs) < _MAX_MSG_LIST_MSG_COUNT and size < _MAX_MSG_LIST_BYTES:
      try:
        msg = self._sender_queue.get(timeout=1)
      except queue.Empty:
        break
      msgs.append(msg)
      size += len(msg.SerializeToBytes())

    # `msgs` always contains at least `first_msg`, so there is always
    # something to send.
    self._SendMessages(msgs)

  def _ReceiveOp(self):
    """Receives a single message through Fleetspeak.

    The payload must be a GrrMessage. It is marked AUTHENTICATED, because
    Fleetspeak ensures authentication, and queued on the client worker.

    Raises:
      BrokenFSConnectionError: If reading from the local Fleetspeak connection
        fails.
      ValueError: If the payload's type name does not end with "GrrMessage".
    """
    try:
      fs_msg, _ = self._fs.Recv()
    except (OSError, struct.error) as e:
      logging.critical("Broken local Fleetspeak connection (read end): %s", e)
      raise BrokenFSConnectionError() from e

    received_type = fs_msg.data.TypeName()
    if not received_type.endswith("GrrMessage"):
      raise ValueError(
          "Unexpected proto type received through Fleetspeak: %r; expected "
          "grr.GrrMessage." % received_type
      )

    grr_msg = rdf_flows.GrrMessage.FromSerializedBytes(fs_msg.data.value)
    # Authentication is ensured by Fleetspeak.
    grr_msg.auth_state = jobs_pb2.GrrMessage.AUTHENTICATED

    self._threads["Worker"].QueueMessages([grr_msg])


class _FleetspeakQueueForwarder(object):
  """Ducktyped replacement for SizeLimitedQueue; forwards to a queue.Queue.

  Every message passed to Put is placed on the queue given to the
  constructor. In this file, that queue is drained by
  GRRFleetspeakClient._SendOp. Only the producer side (Put/Size/Full) is
  supported.
  """

  def __init__(self, sender_queue):
    """Constructor.

    Args:
      sender_queue: queue.Queue that receives every message passed to Put.
    """
    self._sender_queue = sender_queue
    # Called before every enqueue attempt of a blocking Put, so liveness keeps
    # being reported while waiting for space. No-op by default;
    # GRRFleetspeakClient replaces it with the worker's Heartbeat.
    self.heart_beat_cb = lambda: None

  def Put(self, grr_msg, block=True, timeout=None):
    """Places a message in the queue.

    Args:
      grr_msg: The message to enqueue.
      block: If False, fail immediately when the queue is full.
      timeout: Maximum number of seconds to wait when blocking. None waits
        forever. 0 or a negative value makes a single attempt that does not
        wait, matching queue.Queue.put semantics.

    Raises:
      queue.Full: If the message could not be enqueued. This happens
        immediately when `block` is False, otherwise once `timeout` has
        elapsed.
    """
    if not block:
      self._sender_queue.put(grr_msg, block=False)
      return

    # Monotonic clock: a deadline must not move when the wall clock is reset.
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
      self.heart_beat_cb()
      if deadline is None:
        wait = 1
      else:
        # Wait in slices of at most 1s (so the heartbeat keeps firing) but
        # never past the deadline; a wait of 0 is a single non-waiting try.
        wait = min(1, max(0, deadline - time.monotonic()))
      try:
        self._sender_queue.put(grr_msg, timeout=wait)
        return
      except queue.Full:
        if deadline is not None and time.monotonic() >= deadline:
          raise

  def Get(self):
    raise NotImplementedError("This implementation only supports input.")

  def Size(self):
    """Returns the *approximate* size of the queue.

    See: https://docs.python.org/3/library/queue.html#queue.Queue.qsize

    Returns:
      int
    """
    return self._sender_queue.qsize()

  def Full(self):
    """Returns True if the underlying queue is (approximately) full."""
    return self._sender_queue.full()
