#!/usr/bin/env python
"""AFF4 RDFValue implementations.

This module contains all RDFValue implementations.

NOTE: This module uses the class registry to contain all implementations of
RDFValue class, regardless of where they are defined. To do this reliably, these
implementations must be imported _before_ the relevant classes are referenced
from this module.
"""

import abc
import calendar
import datetime
import functools
import logging
import posixpath
import re
import time
from typing import Any, Optional, Text, Union, cast
import zlib

import dateutil
from dateutil import parser

from google.protobuf import timestamp_pb2
from grr_response_core.lib import registry
from grr_response_core.lib import utils
from grr_response_core.lib.util import precondition
from grr_response_core.lib.util import random
from grr_response_core.lib.util import text

# Pending late-binding callbacks, keyed by the name of the RDFValue class they
# wait for. Each value is a list of (callback, kwargs) pairs.
_LATE_BINDING_STORE = {}


def RegisterLateBindingCallback(target_name, callback, **kwargs):
  """Queues `callback` to run once an RDFValue class named `target_name` exists.

  Args:
    target_name: Name of the RDFValue class to wait for.
    callback: Callable that will later be invoked as
      `callback(target=cls, **kwargs)`.
    **kwargs: Extra keyword arguments forwarded to `callback`.
  """
  pending = _LATE_BINDING_STORE.get(target_name)
  if pending is None:
    pending = []
    _LATE_BINDING_STORE[target_name] = pending
  pending.append((callback, kwargs))


class Error(Exception):
  """Base class for all errors raised by RDFValue parsers."""


class InitializeError(Error):
  """Raised when an RDFValue cannot be built from the given initializer."""


class DecodeError(InitializeError, ValueError):
  """Raised when data cannot be decoded.

  The message is logged at debug level when the exception is constructed.
  """

  def __init__(self, msg):
    logging.debug(msg)
    super(DecodeError, self).__init__(msg)


class RDFValueMetaclass(registry.MetaclassRegistry):
  """Metaclass for semantic values that fires pending late-binding callbacks."""

  def __init__(cls, name, bases, env_dict):  # pylint: disable=no-self-argument
    super().__init__(name, bases, env_dict)

    # Callbacks queued under this class name are removed from the store before
    # being run, so each one fires at most once.
    pending = _LATE_BINDING_STORE.pop(name, [])
    for callback, kwargs in pending:
      callback(target=cls, **kwargs)


class RDFValue(metaclass=RDFValueMetaclass):  # pylint: disable=invalid-metaclass
  """Baseclass for values.

  RDFValues are serialized to and from the data store.
  """

  # This is how the attribute will be serialized to the data store. It must
  # indicate both the type emitted by SerializeToWireFormat() and expected by
  # FromWireFormat()
  protobuf_type = "bytes"

  # URL pointing to a help page about this value type.
  context_help_url = None

  _value = None
  # Hash returned by the first __hash__ call; later calls compare against it to
  # detect mutation of a hashed object.
  _prev_hash = None

  # Mark as dirty each time we modify this object.
  dirty = False

  # If this value was created as part of an AFF4 attribute, the attribute is
  # assigned here.
  attribute_instance = None

  def __init__(self):
    self._prev_hash = None

  def Copy(self):
    """Make a new copy of this RDFValue via a serialization round trip."""
    return self.__class__.FromSerializedBytes(self.SerializeToBytes())

  def __copy__(self):
    return self.Copy()

  @classmethod
  def FromWireFormat(cls, value):
    """Builds an instance from its wire format; subclasses must override."""
    raise NotImplementedError(
        "Class {} does not implement FromWireFormat.".format(cls.__name__)
    )

  @classmethod
  def FromSerializedBytes(cls, value: bytes):
    """Builds an instance from SerializeToBytes() output; must be overridden."""
    raise NotImplementedError(
        "Class {} does not implement FromSerializedBytes.".format(cls.__name__)
    )

  # TODO: Remove legacy SerializeToWireFormat.
  def SerializeToWireFormat(self):
    """Serialize to a datastore compatible form."""
    return self.SerializeToBytes()

  @abc.abstractmethod
  def SerializeToBytes(self):
    """Serialize into a string which can be parsed using FromSerializedBytes."""

  @classmethod
  def Fields(cls):
    """Return a list of fields which can be queried from this value."""
    return []

  def __eq__(self, other):
    return self._value == other

  def __ne__(self, other):
    """Inverts `__eq__`, passing `NotImplemented` through unchanged.

    Several subclasses return `NotImplemented` from `__eq__` for foreign
    operand types. Negating that sentinel would wrongly yield `False` (and is
    deprecated), so it is returned as-is to let Python try the reflected
    comparison and finally fall back to an identity check.
    """
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result

  def __hash__(self):
    """Hashes the serialized form and detects mutation after hashing.

    Raises:
      AssertionError: If the hash differs from the one previously returned.
    """
    new_hash = hash(self.SerializeToBytes())
    if self._prev_hash is not None and new_hash != self._prev_hash:
      raise AssertionError(
          "Usage of {} violates Python data model: hash() has changed! Usage "
          "of RDFStructs as members of sets or keys of dicts is discouraged. "
          "If used anyway, mutating is prohibited, because it causes the hash "
          "to change. Be aware that accessing unset fields can trigger a "
          "mutation.".format(type(self).__name__)
      )
    self._prev_hash = new_hash
    return new_hash

  def __bool__(self):
    return bool(self._value)

  def __str__(self):
    """Ignores the __repr__ override below to avoid indefinite recursion."""
    return super().__repr__()

  def __repr__(self):
    content = str(self)

    # Note %r, which prevents nasty nonascii characters from being printed,
    # including dangerous terminal escape sequences.
    return "<%s(%r)>" % (self.__class__.__name__, content)


# The plain Python `bool` type is registered under both of these names.
for _bool_alias in ("bool", "RDFBool"):
  RDFValue.classes[_bool_alias] = bool
del _bool_alias


class RDFPrimitive(RDFValue):
  """Immutable RDFValue wrapping a single primitive (e.g. an int)."""

  # Class-level default; each instance sets its own value in __init__.
  _primitive_value = None

  def __init__(self, initializer):
    super(RDFPrimitive, self).__init__()
    self._primitive_value = initializer

  @property
  def _value(self):
    """The wrapped primitive; read-only because no setter is defined."""
    return self._primitive_value

  @classmethod
  def FromHumanReadable(cls, string: Text):
    """Returns a new instance from a human-readable string.

    Args:
      string: An `unicode` value to initialize the object from.

    Raises:
      NotImplementedError: Always, unless a subclass overrides this method.
    """
    msg = "Class {} does not implement FromHumanReadable.".format(cls.__name__)
    raise NotImplementedError(msg)


@functools.total_ordering
class RDFBytes(RDFPrimitive):
  """An immutable RDFValue holding a raw `bytes` payload.

  The textual form (`str()`) is the hex encoding produced by `text.Hexify`.
  """

  protobuf_type = "bytes"

  def __init__(self, initializer=None):
    if isinstance(initializer, RDFBytes):
      payload = initializer.AsBytes()
    elif isinstance(initializer, bytes):
      payload = initializer
    elif initializer is None:
      payload = b""
    else:
      message = "Unexpected initializer `{value}` of type {type}"
      raise TypeError(message.format(value=initializer, type=type(initializer)))
    super().__init__(payload)

  @classmethod
  def FromSerializedBytes(cls, value: bytes):
    """Wraps serialized bytes verbatim."""
    precondition.AssertType(value, bytes)
    return cls(value)

  @classmethod
  def FromWireFormat(cls, value: bytes):
    """Wraps wire-format bytes verbatim."""
    precondition.AssertType(value, bytes)
    return cls(value)

  @classmethod
  def FromHumanReadable(cls, string: Text):
    """Builds an instance from the UTF-8 encoding of `string`."""
    precondition.AssertType(string, Text)
    encoded = string.encode("utf-8")
    return cls(encoded)

  def AsBytes(self):
    """Returns the wrapped `bytes` object."""
    return self._value

  def SerializeToBytes(self):
    return self.AsBytes()

  def __str__(self) -> Text:
    return text.Hexify(self.AsBytes())

  def __hash__(self):
    return hash(self.AsBytes())

  def __lt__(self, other):
    # Instances of this class are unwrapped; anything else is compared as-is.
    rhs = other.AsBytes() if isinstance(other, self.__class__) else other
    return self.AsBytes() < rhs

  def __eq__(self, other):
    rhs = other.AsBytes() if isinstance(other, self.__class__) else other
    return self.AsBytes() == rhs

  def __len__(self):
    return len(self.AsBytes())


class RDFZippedBytes(RDFBytes):
  """RDFBytes whose payload is zlib-compressed data."""

  def Uncompress(self):
    """Returns the decompressed payload, or b"" when the payload is empty."""
    if not self:
      return b""
    return zlib.decompress(self._value)


@functools.total_ordering
class RDFString(RDFPrimitive):
  """Represent a simple string.

  The value is always stored as text; `bytes` initializers are decoded as
  UTF-8.
  """

  protobuf_type = "string"

  # TODO(hanuszczak): Allow initializing from arbitrary `unicode`-able object.
  def __init__(self, initializer=None):
    # A single if/elif chain: previously the `None` case used a separate `if`
    # and fell through into the remaining type checks.
    if initializer is None:
      super().__init__("")
    elif isinstance(initializer, RDFString):
      super().__init__(str(initializer))
    elif isinstance(initializer, bytes):
      super().__init__(initializer.decode("utf-8"))
    elif isinstance(initializer, Text):
      super().__init__(initializer)
    else:
      message = "Unexpected initializer `%s` of type `%s`"
      message %= (initializer, type(initializer))
      raise TypeError(message)

  def format(self, *args, **kwargs):  # pylint: disable=invalid-name
    return self._value.format(*args, **kwargs)

  def split(self, *args, **kwargs):  # pylint: disable=invalid-name
    return self._value.split(*args, **kwargs)

  def __str__(self) -> Text:
    return self._value

  def __hash__(self):
    return hash(self._value)

  def __getitem__(self, item):
    return self._value.__getitem__(item)

  def __len__(self):
    return len(self._value)

  def __eq__(self, other):
    if isinstance(other, RDFString):
      return self._value == other._value  # pylint: disable=protected-access

    if isinstance(other, Text):
      return self._value == other

    # TODO(hanuszczak): Comparing `RDFString` and `bytes` should result in type
    # error. For now we allow it because too many tests still use non-unicode
    # string literals.
    if isinstance(other, bytes):
      return self._value.encode("utf-8") == other

    return NotImplemented

  def __lt__(self, other):
    if isinstance(other, RDFString):
      return self._value < other._value  # pylint: disable=protected-access

    if isinstance(other, Text):
      return self._value < other

    # TODO(hanuszczak): Comparing `RDFString` and `bytes` should result in type
    # error. For now we allow it because too many tests still use non-unicode
    # string literals.
    if isinstance(other, bytes):
      return self._value.encode("utf-8") < other

    return NotImplemented

  @classmethod
  def FromSerializedBytes(cls, value: bytes):
    """Builds an instance from UTF-8 encoded bytes."""
    precondition.AssertType(value, bytes)
    return cls(value)

  @classmethod
  def FromWireFormat(cls, value):
    """Builds an instance from the text stored in a `string` proto field."""
    return cls.FromHumanReadable(value)

  @classmethod
  def FromHumanReadable(cls, string: Text):
    precondition.AssertType(string, Text)
    return cls(string)

  def SerializeToBytes(self):
    return self._value.encode("utf-8")

  def SerializeToWireFormat(self):
    return self._value


# TODO(hanuszczak): This class should provide custom method for parsing from
# human readable strings (and arguably should not derive from `RDFBytes` at
# all).
class HashDigest(RDFBytes):
  """Binary hash digest whose textual form is its hex encoding."""

  protobuf_type = "bytes"

  def HexDigest(self) -> Text:
    """Returns the digest rendered by `text.Hexify`."""
    return text.Hexify(self._value)

  def __str__(self) -> Text:
    return self.HexDigest()

  def __hash__(self):
    return hash(self._value)

  # TODO(hanuszczak): This is a terrible equality definition.
  def __eq__(self, other):
    # Digests and plain bytes compare by raw bytes; text is compared against
    # the hex rendering.
    if isinstance(other, HashDigest):
      other_bytes = other._value  # pylint: disable=protected-access
      return self._value == other_bytes
    elif isinstance(other, bytes):
      return self._value == other
    elif isinstance(other, Text):
      return other == str(self)
    else:
      return NotImplemented

  def __ne__(self, other):
    return not (self == other)


@functools.total_ordering
class RDFInteger(RDFPrimitive):
  """Represent an integer.

  Arithmetic and bitwise operators delegate to the wrapped `int` and return
  plain Python numbers rather than `RDFInteger` instances.
  """

  protobuf_type = "integer"

  @staticmethod
  def IsNumeric(value):
    return isinstance(value, (int, float, RDFInteger))

  def __init__(self, initializer=None):
    if initializer is None:
      super().__init__(0)
    else:
      super().__init__(int(initializer))

  def SerializeToBytes(self) -> bytes:
    return str(self._value).encode("ascii")

  @classmethod
  def FromSerializedBytes(cls, value: bytes):
    """Parses ASCII decimal bytes; empty input yields zero."""
    precondition.AssertType(value, bytes)

    if value:
      return cls(int(value))
    else:
      return cls(0)

  @classmethod
  def FromHumanReadable(cls, string: Text):
    precondition.AssertType(string, Text)
    return cls(int(string))

  def __str__(self) -> Text:
    return str(self._value)

  @classmethod
  def FromWireFormat(cls, value):
    return cls(initializer=value)

  def SerializeToWireFormat(self):
    """Use varint to store the integer."""
    return self._value

  def __long__(self):
    return self._value

  def __int__(self):
    return self._value

  def __float__(self):
    return float(self._value)

  def __index__(self):
    return self._value

  def __lt__(self, other):
    return self._value < other

  def __and__(self, other):
    return self._value & other

  def __rand__(self, other):
    return self._value & other

  def __or__(self, other):
    return self._value | other

  def __ror__(self, other):
    return self._value | other

  def __add__(self, other):
    return self._value + other

  def __radd__(self, other):
    return self._value + other

  def __sub__(self, other):
    return self._value - other

  def __rsub__(self, other):
    return other - self._value

  def __mul__(self, other):
    return self._value * other

  def __rmul__(self, other):
    return self._value * other

  def __div__(self, other):
    # Python 2 name, never invoked by the `/` operator on Python 3. It used to
    # call the non-existent `int.__div__`; kept as an alias for explicit callers.
    return self.__truediv__(other)

  # The division operators use the `/` and `//` operators rather than calling
  # `int.__truediv__` directly. The dunder returns NotImplemented for float or
  # RDFInteger operands, which made e.g. `RDFInteger(3) / 2.0` a TypeError.
  def __truediv__(self, other):
    return self._value / other

  def __rtruediv__(self, other):
    return other / self._value

  def __floordiv__(self, other):
    return self._value // other

  def __rfloordiv__(self, other):
    return other // self._value

  def __hash__(self):
    return hash(self._value)


@functools.total_ordering
class RDFDatetime(RDFPrimitive):
  """A date and time internally stored in MICROSECONDS.

  The internal unit is `1 / converter` seconds; subclasses (such as
  `RDFDatetimeSeconds`) override `converter` to use a coarser unit.
  """

  # Internal ticks per second.
  converter = 1000000
  protobuf_type = "unsigned_integer"

  def __init__(self, initializer=None):
    if initializer is None:
      super().__init__(0)
    elif isinstance(initializer, RDFDatetime):
      # Rescale from the other instance's unit into this class's unit.
      super().__init__(
          initializer.AsMicrosecondsSinceEpoch() * self.converter // 1000000
      )
    elif isinstance(initializer, float):
      raise TypeError("Float initialization not supported.")
    elif isinstance(initializer, (RDFInteger, int)):
      super().__init__(int(initializer))
    else:
      raise InitializeError(
          "Unknown initializer for RDFDateTime: %s %s."
          % (type(initializer), initializer)
      )

  @classmethod
  def FromWireFormat(cls, value: int):
    return cls(initializer=value)

  def SerializeToWireFormat(self) -> int:
    """Use varint to store the integer."""
    return self._value

  def SerializeToBytes(self) -> bytes:
    return str(self._value).encode("ascii")

  @classmethod
  def FromSerializedBytes(cls, value: bytes):
    """Parses ASCII decimal bytes; empty input yields the epoch."""
    precondition.AssertType(value, bytes)

    if value:
      return cls(int(value))
    else:
      return cls(0)

  @classmethod
  def Now(cls):
    return cls(int(time.time() * cls.converter))

  def Format(self, fmt: Text) -> Text:
    """Return the value as a string formatted as per strftime semantics."""
    precondition.AssertType(fmt, Text)

    stime = time.gmtime(self._value / self.converter)
    return time.strftime(fmt, stime)

  def __str__(self) -> Text:
    """Return the date in human readable (UTC)."""
    # TODO: Display microseconds if applicable.
    return self.Format("%Y-%m-%d %H:%M:%S")

  def AsDatetime(self) -> datetime.datetime:
    """Returns the datetime as Python `datetime` without any timezone set.

    The naive result is expressed in UTC. Exact integer arithmetic replaces
    the deprecated `datetime.utcfromtimestamp`, which also went through a float
    and could round the microsecond component.
    """
    epoch = datetime.datetime(1970, 1, 1)
    delta = datetime.timedelta(microseconds=self.AsMicrosecondsSinceEpoch())
    return epoch + delta

  def AsDatetimeUTC(self) -> datetime.datetime:
    """Returns the datetime as a Python `datetime` with UTC timezone."""
    return self.AsDatetime().replace(tzinfo=datetime.timezone.utc)

  def AsProtoTimestamp(self) -> timestamp_pb2.Timestamp:
    """Returns the datetime as standard Protocol Buffers timestamp message."""
    micros = self.AsMicrosecondsSinceEpoch()

    timestamp = timestamp_pb2.Timestamp()
    timestamp.seconds = micros // 1_000_000
    timestamp.nanos = (micros % 1_000_000) * 1_000
    return timestamp

  def AsSecondsSinceEpoch(self):
    return self._value // self.converter

  def AsMicrosecondsSinceEpoch(self):
    return self._value * 1000000 // self.converter

  @classmethod
  def FromSecondsSinceEpoch(cls, value):
    # Convert to int in case we get fractional seconds with higher
    # resolution than what this class supports.
    return cls(int(value * cls.converter))

  @classmethod
  def FromMicrosecondsSinceEpoch(cls, value):
    precondition.AssertType(value, int)
    return cls(value * cls.converter // 1000000)

  @classmethod
  def FromDatetime(cls, value):
    """Converts a `datetime`; naive values are interpreted as UTC."""
    seconds = calendar.timegm(value.utctimetuple())
    return cls(
        (seconds * cls.converter)
        + (value.microsecond * cls.converter // 1000000)
    )

  @classmethod
  def FromProtoTimestamp(
      cls,
      timestamp: timestamp_pb2.Timestamp,
  ) -> "RDFDatetime":
    """Converts Protocol Buffers `Timestamp` instances to datetime objects.

    Args:
      timestamp: A Protocol Buffers `Timestamp` instance to convert.

    Returns:
      A corresponding RDF datetime object of the class this is invoked on.
    """
    micros = timestamp.seconds * 1_000_000 + timestamp.nanos // 1_000
    # Use `cls` (not a hardcoded `RDFDatetime`) so subclasses get their own
    # type back, consistent with the other `From*` constructors.
    return cls.FromMicrosecondsSinceEpoch(micros)

  @classmethod
  def FromDate(cls, value):
    seconds = calendar.timegm(value.timetuple())
    return cls(seconds * cls.converter)

  @classmethod
  def Lerp(cls, t, start_time, end_time):
    """Interpolates linearly between two datetime values.

    Args:
      t: An interpolation "progress" value.
      start_time: A value for t = 0.
      end_time: A value for t = 1.

    Returns:
      An interpolated `RDFDatetime` instance.

    Raises:
      TypeError: If given time values are not instances of `RDFDatetime`.
      ValueError: If `t` parameter is not between 0 and 1.
    """
    if not (
        isinstance(start_time, RDFDatetime)
        and isinstance(end_time, RDFDatetime)
    ):
      raise TypeError("Interpolation of non-datetime values")

    if not 0.0 <= t <= 1.0:
      raise ValueError("Interpolation progress does not belong to [0.0, 1.0]")

    return cls(round((1 - t) * start_time._value + t * end_time._value))  # pylint: disable=protected-access

  def __add__(self, other):
    if isinstance(other, float):
      raise TypeError("Float initialization not supported.")
    elif isinstance(other, int):
      # Assume other is in seconds
      other_microseconds = int(other * self.converter)
      return self.__class__(self._value + other_microseconds)
    elif isinstance(other, (DurationSeconds, Duration)):
      self_us = self.AsMicrosecondsSinceEpoch()
      duration_us = other.microseconds
      return self.__class__.FromMicrosecondsSinceEpoch(self_us + duration_us)

    return NotImplemented

  def __sub__(self, other):
    if isinstance(other, float):
      raise TypeError("Float initialization not supported.")
    elif isinstance(other, int):
      # Assume other is in seconds
      other_microseconds = int(other * self.converter)
      return self.__class__(self._value - other_microseconds)
    elif isinstance(other, (DurationSeconds, Duration)):
      self_us = self.AsMicrosecondsSinceEpoch()
      duration_us = other.microseconds
      return self.__class__.FromMicrosecondsSinceEpoch(self_us - duration_us)
    elif isinstance(other, RDFDatetime):
      # Duration is non-negative, so this raises ValueError when `other` is
      # later than `self`; use AbsDiff() for an order-independent difference.
      diff_us = (
          self.AsMicrosecondsSinceEpoch() - other.AsMicrosecondsSinceEpoch()
      )
      return Duration.From(diff_us, MICROSECONDS)

    return NotImplemented

  @classmethod
  def FromHumanReadable(cls, string: Text, eoy=False):
    """Parses a human readable string of a timestamp.

    The string is interpreted as UTC unless it specifies a timezone itself.
    Fractional seconds in the string are discarded.

    Args:
      string: The string to parse.
      eoy: If True, sets the default value to the end of the year. Usually this
        method returns a timestamp where each field that is not present in the
        given string is filled with values from the date January 1st of the
        current year, midnight. Sometimes it makes more sense to compare against
        the end of a period so if eoy is set, the default values are copied from
        the 31st of December of the current year, 23:59h.

    Returns:
      A new instance based on the given string.
    """
    # TODO(hanuszczak): This method should accept only unicode literals.
    # TODO(hanuszczak): Date can come either as a single integer (which we
    # interpret as a timestamp) or as a really human readable thing such as
    # '2000-01-01 13:37'. This is less than ideal (since timestamps are not
    # really "human readable") and should be fixed in the future.
    try:
      return cls(int(string))
    except ValueError:
      pass

    # By default assume the time is given in UTC.
    # pylint: disable=g-tzinfo-datetime
    if eoy:
      default = datetime.datetime(
          time.gmtime().tm_year, 12, 31, 23, 59, tzinfo=dateutil.tz.tzutc()
      )
    else:
      default = datetime.datetime(
          time.gmtime().tm_year, 1, 1, 0, 0, tzinfo=dateutil.tz.tzutc()
      )
    # pylint: enable=g-tzinfo-datetime

    timestamp = parser.parse(string, default=default)

    raw = calendar.timegm(timestamp.utctimetuple()) * cls.converter
    return cls(raw)

  def Floor(self, interval):
    """Rounds the timestamp down to a multiple of `interval` since the epoch.

    Alignment is done in microseconds, so sub-second intervals (and intervals
    such as 1500ms) are honoured instead of being truncated to whole seconds.

    Args:
      interval: A non-zero `Duration` to align to.

    Returns:
      A new instance of the same class, aligned to `interval`.

    Raises:
      ValueError: If `interval` is zero.
    """
    precondition.AssertType(interval, Duration)
    step_us = interval.microseconds
    if step_us == 0:
      raise ValueError("Cannot floor to a zero-length interval.")
    micros = self.AsMicrosecondsSinceEpoch()
    # Python's `%` with a positive divisor is non-negative, so this floors
    # correctly for pre-epoch timestamps as well.
    return self.FromMicrosecondsSinceEpoch(micros - micros % step_us)

  def __hash__(self):
    return hash(self._value)

  def __eq__(self, other):
    if isinstance(other, RDFDatetime):
      return self.AsMicrosecondsSinceEpoch() == other.AsMicrosecondsSinceEpoch()
    else:
      return self._value == other

  def __lt__(self, other):
    if isinstance(other, RDFDatetime):
      return self.AsMicrosecondsSinceEpoch() < other.AsMicrosecondsSinceEpoch()
    else:
      return self._value < other

  def __int__(self):
    return self._value

  def AbsDiff(self, other: "RDFDatetime") -> "Duration":
    """Returns the non-negative Duration between `self` and `other`."""
    if self > other:
      return self - other
    return other - self


class RDFDatetimeSeconds(RDFDatetime):
  """RDFDatetime variant whose internal value counts whole seconds.

  Only the unit differs: with `converter` set to 1, the conversion helpers
  inherited from `RDFDatetime` treat the stored integer as seconds.
  """

  converter = 1  # One internal tick per second.


# Time units used by `Duration` methods, each expressed as a number of
# microseconds (the base unit of `Duration`).
MICROSECONDS = 1
MILLISECONDS = MICROSECONDS * 1000
SECONDS = MILLISECONDS * 1000
MINUTES = SECONDS * 60
HOURS = MINUTES * 60
DAYS = HOURS * 24
WEEKS = DAYS * 7

# Digits, an optional single space, then an optional unit of one or two
# lowercase letters (e.g. "5m", "10 ms").
_DURATION_RE = re.compile(r"(?P<number>\d+) ?(?P<unit>[a-z]{1,2})?")


@functools.total_ordering
class Duration(RDFPrimitive):
  """Absolute duration between instants in time with microsecond precision.

  The duration is stored as non-negative integer, guaranteeing microsecond
  precision up to MAX_UINT64 microseconds (584k years).
  """

  protobuf_type = "unsigned_integer"

  # Unit suffixes and their size in microseconds, largest first. `__str__`
  # relies on this order to pick the coarsest unit that divides evenly.
  _DIVIDERS = dict((
      ("w", WEEKS),
      ("d", DAYS),
      ("h", HOURS),
      ("m", MINUTES),
      ("s", SECONDS),
      ("ms", MILLISECONDS),
      ("us", MICROSECONDS),
  ))

  def __init__(self, initializer=None):
    """Instantiates a new microsecond-based Duration.

    Args:
      initializer: Integer specifying microseconds, another Duration to copy,
        text such as "5m" (a unit suffix is required), or None for a zero
        Duration.

    Raises:
      ValueError: If the initializer is negative or the text cannot be parsed.
      TypeError: If the initializer has an unsupported type.
    """
    if isinstance(initializer, Duration):
      if initializer.microseconds < 0:
        raise ValueError("Negative Duration (%s us)" % initializer.microseconds)
      super().__init__(initializer.microseconds)
    elif isinstance(initializer, (int, RDFInteger)):
      # Integer initializers are microseconds.
      if int(initializer) < 0:
        raise ValueError("Negative Duration (%s us)" % initializer)
      super().__init__(int(initializer))
    elif isinstance(initializer, Text):
      super().__init__(self._ParseText(initializer, default_unit=None))
    elif initializer is None:
      super().__init__(0)
    else:
      message = "Unsupported initializer `{value}` of type `{type}`"
      raise TypeError(message.format(value=initializer, type=type(initializer)))

  @classmethod
  def From(cls, value: Union[int, float], timeunit: int):
    """Returns a new Duration given a timeunit and value.

    Args:
      value: A number specifying the value of the duration.
      timeunit: A unit of time ranging from rdfvalue.MICROSECONDS to
        rdfvalue.WEEKS.

    Examples:
      >>> Duration.From(50, MICROSECONDS)
      <Duration 50 us>

      >>> Duration.From(120, SECONDS)
      <Duration 2 m>

    Returns:
      A new Duration.
    """
    return cls(int(timeunit * value))

  @classmethod
  def FromWireFormat(cls, value):
    precondition.AssertType(value, int)
    return cls(value)

  def SerializeToWireFormat(self):
    """See base class."""
    return self.microseconds

  @classmethod
  def FromSerializedBytes(cls, value: bytes):
    """Parses ASCII decimal bytes; empty input yields a zero Duration.

    Raises:
      DecodeError: If `value` is not a decimal integer.
    """
    precondition.AssertType(value, bytes)

    if not value:
      return cls(0)

    try:
      raw = abs(int(value))
    except ValueError as e:
      raise DecodeError(e) from e

    return cls(raw)

  def SerializeToBytes(self) -> bytes:
    """See base class."""
    # Technically, equal to ascii encoding, since str(self._value) only contains
    # the digits 0-9.
    return str(self.microseconds).encode("utf-8")

  def __repr__(self):
    return "<{} {}>".format(type(self).__name__, self)

  def __str__(self) -> Text:
    if self._value == 0:
      return "0 us"
    for label, divider in self._DIVIDERS.items():
      if self._value % divider == 0:
        return "%d %s" % (self._value // divider, label)
    return "%d us" % self._value  # Make pytype happy.

  def __add__(self, other):
    if isinstance(other, Duration):
      return self.__class__.From(
          self.microseconds + other.microseconds, MICROSECONDS
      )
    else:
      return NotImplemented

  def __sub__(self, other):
    # A negative result raises ValueError via __init__.
    if isinstance(other, Duration):
      return self.__class__.From(
          self.microseconds - other.microseconds, MICROSECONDS
      )
    else:
      return NotImplemented

  def __mul__(self, other):
    if isinstance(other, int):
      return self.__class__.From(self.microseconds * other, MICROSECONDS)
    else:
      return NotImplemented

  def __rmul__(self, other):
    return self.__mul__(other)

  def __lt__(self, other):
    if isinstance(other, Duration):
      return self.microseconds < other.microseconds
    else:
      return NotImplemented

  def __eq__(self, other):
    if isinstance(other, Duration):
      return self.microseconds == other.microseconds
    else:
      return NotImplemented

  def __hash__(self):
    # Defining __eq__ without __hash__ makes Python set __hash__ to None,
    # which left Duration (and DurationSeconds) unhashable. Hash the same key
    # that __eq__ compares so equal durations hash equally across subclasses.
    return hash(self.microseconds)

  def __abs__(self):
    return self

  def ToInt(self, timeunit: int) -> int:
    """Returns the duration as truncated integer, converted to the time unit.

    All fractions are truncated. To preserve them, use `ToFractional()`.

    Examples:
      >>> Duration.From(2, WEEKS).ToInt(DAYS)
      14

      >>> Duration.From(100, SECONDS).ToInt(SECONDS)
      100

      >>> Duration.From(6, DAYS).ToInt(WEEKS)
      0

    Args:
      timeunit: A unit of time ranging from rdfvalue.MICROSECONDS to
        rdfvalue.WEEKS.

    Returns:
      An integer, representing the duration in the specific unit, truncating
      fractions.
    """
    return self.microseconds // timeunit

  def ToFractional(self, timeunit: int) -> float:
    """Returns the duration as float, converted to the given time unit.

    Examples:
      >>> Duration.From(30, SECONDS).ToFractional(MINUTES)
      0.5

      >>> Duration.From(100, SECONDS).ToFractional(SECONDS)
      100.0

      >>> Duration.From(6, MINUTES).ToFractional(HOURS)
      0.1

    Args:
      timeunit: A unit of time ranging from rdfvalue.MICROSECONDS to
        rdfvalue.WEEKS.

    Returns:
      A float, representing the duration in the specific unit, including
      fractions.
    """
    return self.microseconds / timeunit

  def AsTimedelta(self) -> datetime.timedelta:
    """Returns a standard `timedelta` object corresponding to the duration."""
    return datetime.timedelta(microseconds=self.ToInt(MICROSECONDS))

  @property
  def microseconds(self):
    return self._value

  @classmethod
  def FromHumanReadable(cls, string: Text):
    """See base class."""
    return cls(cls._ParseText(string, default_unit=None))

  @classmethod
  def _ParseText(cls, string: Text, default_unit: Optional[Text]) -> int:
    """Parses a textual representation of a duration.

    Args:
      string: Text such as "10 ms" or "5m"; empty text yields 0.
      default_unit: Unit assumed when the text has no unit suffix, or None to
        reject unit-less text.

    Returns:
      The duration in microseconds.

    Raises:
      ValueError: If the text does not start with a number or the unit is
        unknown.
    """
    precondition.AssertType(string, Text)

    if not string:
      return 0

    matches = _DURATION_RE.match(string)
    if matches is None:
      raise ValueError("Could not parse duration {!r}.".format(string))

    number = int(matches.group("number"))
    unit_string = matches.group("unit")

    if unit_string is None:
      unit_string = default_unit

    try:
      unit_multiplier = cls._DIVIDERS[unit_string]
    except KeyError as ex:
      raise ValueError(
          "Invalid unit {!r} for duration in {!r}. Expected any of {}.".format(
              unit_string, string, ", ".join(cls._DIVIDERS)
          )
      ) from ex

    return number * unit_multiplier

  def Expiry(self, base_time=None):
    """Returns `base_time` (default: RDFDatetime.Now()) plus this duration."""
    if base_time is None:
      base_time = RDFDatetime.Now()
    return base_time + self


class DurationSeconds(Duration):
  """Duration that is (de)serialized with second-precision.

  This class exists for compatibility purposes and to keep certain API fields
  simple. For most uses, please prefer `Duration` directly.
  """

  def __init__(self, initializer: Any = None):
    # Unlike `Duration`, bare integers and unit-less text mean seconds here.
    if isinstance(initializer, Text):
      initializer = self._ParseText(initializer, default_unit="s")
    elif isinstance(initializer, (int, RDFInteger)):
      initializer = SECONDS * int(initializer)
    super().__init__(initializer)

  def SerializeToBytes(self) -> bytes:
    """See base class. Emits whole seconds as ASCII digits."""
    return b"%d" % self.ToInt(SECONDS)

  def SerializeToWireFormat(self):
    """See base class. Emits whole seconds."""
    whole_seconds = self.ToInt(SECONDS)
    return whole_seconds

  @classmethod
  def FromHumanReadable(cls, string: Text):
    precondition.AssertType(string, Text)
    return cls(string)

  def __str__(self) -> Text:
    # A zero duration is spelled in this class's natural unit.
    return "0 s" if self.microseconds == 0 else super().__str__()

  @classmethod
  def From(cls, value: Union[int, float], timeunit: int):
    """See base class. The result is floored to whole seconds."""
    whole_seconds = int(timeunit * value // SECONDS)
    return cls(whole_seconds)


class ByteSize(RDFInteger):
  """A size for bytes allowing standard unit prefixes.

  We use the standard IEC 60027-2 A.2 and ISO/IEC 80000:
  Binary units (powers of 2): Ki, Mi, Gi
  SI units (powers of 10): k, m, g
  """

  protobuf_type = "unsigned_integer"

  # Lowercase unit prefix -> multiplier in bytes.
  DIVIDERS = dict((
      ("", 1),
      ("k", 1000),
      ("m", 1000**2),
      ("g", 1000**3),
      ("ki", 1024),
      ("mi", 1024**2),
      ("gi", 1024**3),
  ))

  REGEX = re.compile("^([0-9.]+) ?([kmgi]*)b?$", re.I)

  def __init__(self, initializer=None):
    if isinstance(initializer, ByteSize):
      super().__init__(initializer._value)  # pylint: disable=protected-access
    elif isinstance(initializer, str):
      super().__init__(self._ParseText(initializer))
    elif isinstance(initializer, (int, float)):
      super().__init__(initializer)
    elif isinstance(initializer, RDFInteger):
      super().__init__(int(initializer))
    elif initializer is None:
      super().__init__(0)
    else:
      raise InitializeError(
          "Unknown initializer for ByteSize: %s." % type(initializer)
      )

  def __str__(self):
    """Renders the size with one decimal in the largest fitting binary unit."""
    if self._value >= 1024**3:
      unit = "GiB"
      value = self._value / 1024**3
    elif self._value >= 1024**2:
      unit = "MiB"
      value = self._value / 1024**2
    elif self._value >= 1024:
      unit = "KiB"
      value = self._value / 1024
    else:
      return "{} B".format(self._value)

    return "{value:.1f} {unit}".format(value=value, unit=unit)

  @classmethod
  def FromHumanReadable(cls, string: Text):
    return cls(cls._ParseText(string))

  @classmethod
  def _ParseText(cls, string: Text) -> int:
    """Parses a textual representation of the number of bytes.

    Args:
      string: Text such as "10", "1.5 KiB" or "2gb" (case-insensitive).

    Returns:
      The number of bytes, truncated to an integer. Empty text yields 0.

    Raises:
      DecodeError: If the text is not a valid byte-size specification.
    """
    if not string:
      return 0

    match = cls.REGEX.match(string.strip().lower())
    if not match:
      raise DecodeError("Unknown specification for ByteSize %s" % string)

    multiplier = cls.DIVIDERS.get(match.group(2))
    if not multiplier:
      raise DecodeError("Invalid multiplier %s" % match.group(2))

    # The regex admits any run of digits and dots (e.g. "1.2.3" or "."), so
    # the numeric conversion can still fail. Report that as DecodeError like
    # every other malformed input (DecodeError is itself a ValueError).
    number = match.group(1)
    try:
      # The value may be represented as a float, but if it's not, don't lose
      # accuracy.
      value = float(number) if "." in number else int(number)
    except ValueError as e:
      raise DecodeError("Unknown specification for ByteSize %s" % string) from e

    return int(value * multiplier)


@functools.total_ordering
class RDFURN(RDFPrimitive):
  """An object to abstract URL manipulation.

  The URN is stored as a path without the "aff4:" scheme; `str()` renders it
  back with the "aff4:" prefix. Equality and ordering against plain strings
  first convert the string into a URN of the same class, so both "aff4:/a"
  and "/a" go through the same normalization as `RDFURN("/a")`.
  """

  protobuf_type = "string"

  # Careful when changing this value, this is hardcoded a few times in this
  # class for performance reasons.
  scheme = "aff4"

  def __init__(self, initializer=None):
    """Constructor.

    Args:
      initializer: A string (text or UTF-8 bytes), another RDFURN, or None
        (which yields an empty path).
    """
    # This is a shortcut that is a bit faster than the standard way of
    # using the RDFValue constructor to make a copy of the class. For
    # RDFURNs that way is a bit slow since it would try to normalize
    # the path again which is not needed - it comes from another
    # RDFURN so it is already in the correct format.

    if initializer is None:
      super().__init__("")
      return

    if isinstance(initializer, RDFURN):
      # Make a direct copy of the other object
      super().__init__(cast(RDFURN, initializer).Path())
      return

    precondition.AssertType(initializer, (bytes, Text))

    if isinstance(initializer, bytes):
      initializer = initializer.decode("utf-8")

    super().__init__(self._Normalize(initializer))

  @classmethod
  def _Normalize(cls, string):
    """Strips an "aff4:" prefix (keeping the "/") and normalizes the path."""
    if string.startswith("aff4:/"):
      string = string[5:]
    return utils.NormalizePath(string)

  @classmethod
  def FromSerializedBytes(cls, value: bytes):
    precondition.AssertType(value, bytes)
    return cls(value)

  @classmethod
  def FromWireFormat(cls, value):
    # TODO(hanuszczak): We should just assign the `self._value` here
    # instead of including all of the parsing magic since the data store values
    # should be normalized already. But sadly this is not the case and for now
    # we have to deal with unnormalized values as well.
    return cls(value)

  @classmethod
  def FromHumanReadable(cls, string: Text):
    precondition.AssertType(string, Text)
    return cls(string)

  def SerializeToBytes(self) -> bytes:
    """Returns the "aff4:"-prefixed form encoded as UTF-8."""
    return str(self).encode("utf-8")

  def SerializeToWireFormat(self) -> Text:
    """Returns the "aff4:"-prefixed form as text."""
    return str(self)

  def Dirname(self):
    """Returns the directory part of the path (without the "aff4:" prefix)."""
    return posixpath.dirname(self._value)

  def Basename(self):
    """Returns the last path component."""
    return posixpath.basename(self.Path())

  def Add(self, path):
    """Add a relative stem to the current value and return a new RDFURN.

    The join itself is delegated to `utils.JoinPath`. NOTE(review): the
    historical contract says a fully qualified URN replaces the current
    value; whether that holds depends on `utils.JoinPath` — confirm.

    Args:
      path: A string containing a relative path.

    Returns:
       A new instance of the same class, so calls can be chained.

    Raises:
       ValueError: if the path component is not a string.
    """
    if not isinstance(path, str):
      raise ValueError(
          "Only strings should be added to a URN, not %s" % path.__class__
      )
    return self.__class__(utils.JoinPath(self._value, path))

  def __str__(self) -> Text:
    return "aff4:%s" % self._value

  # Required, because in Python 3 overriding `__eq__` nullifies `__hash__`.
  __hash__ = RDFPrimitive.__hash__

  def __eq__(self, other):
    """Compares paths; strings are normalized into a URN first."""
    if isinstance(other, str):
      other = self.__class__(other)

    elif other is None:
      return False

    elif not isinstance(other, RDFURN):
      return NotImplemented

    return self._value == other.Path()

  def __bool__(self):
    return bool(self._value)

  def __lt__(self, other):
    """Orders URNs by path, normalizing strings exactly like `__eq__`.

    Comparing the raw path against an unnormalized string was inconsistent
    with `__eq__` (e.g. `RDFURN("aff4:/b") < "aff4:/a"` compared "/b" with
    "aff4:/a" and returned True). `functools.total_ordering` derives the
    remaining comparisons from this method and `__eq__`.
    """
    if isinstance(other, str):
      other = self.__class__(other)
    elif not isinstance(other, RDFURN):
      # Lets Python try the reflected operation and raise TypeError if that
      # is unsupported too, as comparing a str with such a value did before.
      return NotImplemented

    return self._value < other.Path()

  def Path(self):
    """Return the path of the urn."""
    return self._value

  def Split(self, count=None):
    """Returns all the path components.

    Args:
      count: If count is specified, the output will be exactly this many path
        components, possibly extended with the empty string. This is useful for
        tuple assignments without worrying about ValueErrors:  namespace, path =
        urn.Split(2)

    Returns:
      A list of path components of this URN.
    """
    if count:
      # `split("/", count)` on a path with a leading "/" yields an empty first
      # element, which the filter drops, leaving at most `count` components.
      result = list(filter(None, self._value.split("/", count)))
      while len(result) < count:
        result.append("")

      return result

    else:
      return list(filter(None, self._value.split("/")))

  def RelativeName(self, volume):
    """Given a volume URN return the relative URN as a unicode string.

    We remove the volume prefix from our own.

    NOTE(review): this is a plain string-prefix test, not a path-component
    test, so "aff4:/foobar" relative to "aff4:/foo" yields "bar". Callers may
    rely on this; confirm before tightening.

    Args:
      volume: An RDFURN or fully qualified url string.

    Returns:
      A string of the url relative from the volume or None if our URN does not
      start with the volume prefix.
    """
    string_url = utils.SmartUnicode(self)
    volume_url = utils.SmartUnicode(volume)
    if string_url.startswith(volume_url):
      result = string_url[len(volume_url) :]
      # This must always return a relative path so we strip leading "/"s. The
      # result is always a unicode string.
      return result.lstrip("/")

    return None

  def __repr__(self):
    return "<%s>" % self


class Subject(RDFURN):
  """Pseudo attribute standing for the subject of an AFF4 object.

  Behaves exactly like RDFURN; it adds no methods or attributes of its own.
  """


# Default `queue` for SessionID: its basename becomes the prefix of flow ids
# generated there ("<queue basename>:<flow name>").
DEFAULT_FLOW_QUEUE = RDFURN(initializer="F")


class SessionID(RDFURN):
  """An rdfvalue object that represents a session_id."""

  # This check is weaker than it could be because we allow queues called
  # "DEBUG-user1" and IDs like "TransferStore". We also have to allow
  # flows session ids like H:123456:hunt.
  # Compiled once here instead of on every ValidateID call; used with
  # `fullmatch` because a `$` anchor with `re.match` also accepts a trailing
  # newline (e.g. "F:1234\n").
  _VALID_ID_RE = re.compile(r"[-0-9a-zA-Z]+(:[0-9a-zA-Z]+){0,2}")

  def __init__(
      self,
      initializer=None,
      base="aff4:/flows",
      queue=DEFAULT_FLOW_QUEUE,
      flow_name=None,
  ):
    """Constructor.

    Args:
      initializer: A string or another RDFURN. RDFURN initializers have their
        basename validated with `ValidateID`; strings are used as-is.
      base: The base namespace this session id lives in.
      queue: The queue to use; only its basename is used, as the id prefix.
      flow_name: The name of this flow or its random id. Ints are rendered in
        uppercase hex; when omitted, a random 32-bit value is used.

    Raises:
      InitializeError: The given URN cannot be converted to a SessionID.
    """
    if initializer is None:
      # This SessionID is being constructed from scratch.
      if flow_name is None:
        flow_name = random.UInt32()

      if isinstance(flow_name, int):
        initializer = RDFURN(base).Add("%s:%X" % (queue.Basename(), flow_name))
      else:
        initializer = RDFURN(base).Add("%s:%s" % (queue.Basename(), flow_name))
    else:
      if isinstance(initializer, RDFURN):
        try:
          self.ValidateID(initializer.Basename())
        except ValueError as e:
          raise InitializeError(
              "Invalid URN for SessionID: %s, %s" % (initializer, e)
          ) from e

    super().__init__(initializer=initializer)

  def Add(self, path):
    # Adding to a SessionID results in a normal RDFURN.
    return RDFURN(self).Add(path)

  @classmethod
  def ValidateID(cls, id_str):
    """Checks that `id_str` looks like a session id component.

    Args:
      id_str: The candidate id, e.g. "F:ABCDEF" or "H:123456:hunt".

    Raises:
      ValueError: If the whole of `id_str` does not match the allowed pattern.
    """
    if not cls._VALID_ID_RE.fullmatch(id_str):
      raise ValueError("Invalid SessionID: %s" % id_str)


# TODO(hanuszczak): Remove this class.
class FlowSessionID(SessionID):
  """Subclass of SessionID that adds no behavior of its own."""
