# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Contains a custom unittest loader and test suite.

Implements:
- A test filter based on the JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS
  environment variables.
- A test suite that runs tests in parallel using threads if JAX_TEST_NUM_THREADS
  is >= 1.
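- Helpers that mark code as thread-unsafe: ``thread_unsafe_test`` (a context
  manager that can also be used as a decorator) for part or all of a test, and
  ``thread_unsafe_test_class`` for an entire test class.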
"""

from __future__ import annotations

from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
import logging
import os
import re
import threading
import time
import unittest

from absl.testing import absltest
from jax._src import config
from jax._src import test_warning_util
from jax._src import util

logger = logging.getLogger(__name__)


# Regular expressions used by JaxTestLoader.getTestCaseNames to select tests.
# Each is matched with re.search against "<TestCaseClassName>.<method_name>";
# an empty value disables that filter. Defaults come from the environment.
_TEST_TARGETS = config.string_flag(
  'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
  'Regular expression specifying which tests to run, called via re.search on '
  'the test name. If empty or unspecified, run all tests.'
)

_EXCLUDE_TEST_TARGETS = config.string_flag(
  'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
  'Regular expression specifying which tests NOT to run, called via re.search '
  'on the test name. If empty or unspecified, run all tests.'
)

# Worker thread count for JaxTestSuite; <= 0 runs tests serially. An empty or
# whitespace-only JAX_TEST_NUM_THREADS is treated as unset (0), consistent with
# the empty-means-unspecified convention of the target flags above, instead of
# failing at import time with int('') -> ValueError.
TEST_NUM_THREADS = config.int_flag(
    'jax_test_num_threads',
    int(os.getenv('JAX_TEST_NUM_THREADS', '').strip() or '0'),
    help='Number of threads to use for running tests. 0 means run everything '
    'in the main thread. Using > 1 thread is experimental.'
)

# We use a reader-writer lock to protect test execution. Tests that may run in
# parallel acquire a read lock; tests that are not thread-safe acquire a write
# lock.
# NOTE(review): relies on util.Mutex providing reader_lock/reader_unlock,
# writer_lock/writer_unlock and assert_reader_held (as used below); its
# fairness and reentrancy semantics are not visible from this file.
_test_rwlock = util.Mutex()

def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
  """Runs `test` while holding the global test lock in the appropriate mode.

  Tests whose class has a truthy ``thread_hostile`` attribute (as set by
  ``thread_unsafe_test_class``) take the lock in writer (exclusive) mode; all
  other tests take it in reader (shared) mode and may overlap with each other.
  The lock is released even if running the test raises.
  """
  exclusive = getattr(test.__class__, "thread_hostile", False)
  if exclusive:
    acquire, release = _test_rwlock.writer_lock, _test_rwlock.writer_unlock
  else:
    acquire, release = _test_rwlock.reader_lock, _test_rwlock.reader_unlock

  acquire()
  try:
    test(result)  # type: ignore
  finally:
    release()


@contextmanager
def thread_unsafe_test(condition: bool = True):
  """Context manager (also usable as a decorator) for thread-unsafe test code.

  When tests run on worker threads (``TEST_NUM_THREADS`` > 0) and
  ``condition`` is true, the calling thread gives up its reader hold on the
  shared test lock and takes the writer lock for the duration of the wrapped
  code. This prevents any other test running under ``_run_one_test`` from
  executing concurrently. Afterwards it releases the writer lock and
  re-acquires the reader lock, so the ``reader_unlock()`` in ``_run_one_test``
  still pairs up.

  Because it is built with ``contextlib.contextmanager`` it can be used either
  as ``with thread_unsafe_test(): ...`` or as ``@thread_unsafe_test()``.

  Args:
    condition: If True, mark the test as thread-unsafe. If False, the test
      runs normally without acquiring the write lock. Defaults to True.

  Note: this decorator (naturally) only applies to what it wraps, not to, say,
  code in separate setUp() or tearDown() methods.

  NOTE(review): inside a class marked with ``thread_unsafe_test_class``,
  ``_run_one_test`` holds the writer lock rather than a reader lock. Whether
  ``assert_reader_held()`` / ``reader_unlock()`` below behave correctly in that
  case depends on util.Mutex; confirm before combining the two.
  """
  # Serial mode, or explicitly disabled: there is nothing to serialize against.
  if TEST_NUM_THREADS.value <= 0 or not condition:
    yield
    return

  # The caller is expected to be inside _run_one_test's reader-locked region.
  _test_rwlock.assert_reader_held()
  # Upgrade reader -> writer. This is not atomic: other tests may take the
  # lock in between, which is fine because the wrapped code has not started.
  _test_rwlock.reader_unlock()
  _test_rwlock.writer_lock()
  try:
    yield
  finally:
    # Downgrade back to a reader hold for _run_one_test to release.
    _test_rwlock.writer_unlock()
    _test_rwlock.reader_lock()


def thread_unsafe_test_class(condition: bool = True):
  """Decorator factory that marks a TestCase class as thread-hostile.

  Use it with parentheses: ``@thread_unsafe_test_class()``. When tests run on
  worker threads, every test of a class whose ``thread_hostile`` attribute is
  truthy is run while holding the test lock in writer (exclusive) mode.

  Args:
    condition: If True, mark the test class as thread-hostile. If False, the
      test class runs normally. Defaults to True.

  Returns:
    A class decorator that sets ``thread_hostile = condition`` on the class
    and returns the same class.

  Raises:
    TypeError: if the factory is applied to a class directly (missing
      parentheses), or if the returned decorator is applied to something that
      is not a ``unittest.TestCase`` subclass.
  """
  # A bare `@thread_unsafe_test_class` passes the class in as `condition` and
  # would silently replace it with the inner decorator, hiding all of its
  # tests from the loader. Fail loudly instead.
  if isinstance(condition, type):
    raise TypeError(
        "thread_unsafe_test_class must be called: use "
        "@thread_unsafe_test_class(), not @thread_unsafe_test_class")

  def f(klass):
    # An explicit check (not `assert`) so it also applies under `python -O`;
    # the isinstance guard keeps issubclass() from raising on non-classes.
    if not (isinstance(klass, type) and issubclass(klass, unittest.TestCase)):
      raise TypeError(
          "thread_unsafe_test_class() must decorate a unittest.TestCase "
          f"subclass, got {klass!r}")
    klass.thread_hostile = condition
    return klass
  return f


class ThreadSafeTestResult:
  """
  Wraps a TestResult to make it thread safe.

  We do this by accumulating API calls and applying them in a batch under a
  lock at the conclusion of each test case.

  We duck type instead of inheriting from TestResult because we aren't actually
  a perfect implementation of TestResult, and would rather get a loud error
  for things we haven't implemented.

  Lifecycle of one test:
  - ``startTest`` records the start time.
  - The ``add*`` methods queue callbacks.
  - ``stopTest`` replays ``startTest``, the queued callbacks and ``stopTest`` on
    the wrapped result while holding ``lock``.

  NOTE(review): ``addSubTest`` is intentionally absent. CPython's unittest
  checks for it with ``hasattr`` and otherwise runs ``subTest()`` blocks
  without per-subtest reporting; confirm for the Python versions in use.
  """

  # Sentinel meaning "the wrapped result had no time_getter attribute".
  _MISSING = object()

  def __init__(self, lock: threading.Lock, result: unittest.TestResult):
    self.lock = lock
    self.test_result = result
    # Deferred calls to apply to `test_result` in stopTest().
    self.actions: list[Callable[[], None]] = []

  def startTest(self, test: unittest.TestCase):
    logger.info("Test start: %s", test.id())
    # Only timestamp here. The wrapped result's startTest() is deferred to
    # stopTest() so that all of its calls for this test happen under the lock.
    self.start_time = time.time()

  def stopTest(self, test: unittest.TestCase):
    logger.info("Test stop: %s", test.id())
    stop_time = time.time()
    # Take ownership of the queued actions so that a reused instance never
    # replays them for a later test.
    actions, self.actions = self.actions, []
    with self.lock:
      # If test_result is an ABSL _TextAndXMLTestResult we override how it gets
      # the time. This affects the timing that shows up in the XML output
      # consumed by CI.
      saved_getter = getattr(self.test_result, "time_getter", self._MISSING)
      try:
        self.test_result.time_getter = lambda: self.start_time
        self.test_result.startTest(test)
        for callback in actions:
          callback()
        self.test_result.time_getter = lambda: stop_time
        self.test_result.stopTest(test)
      finally:
        # Restore the wrapped result's original state. If it had no
        # time_getter, remove the one we added instead of leaving a stale
        # lambda on the shared result.
        if saved_getter is self._MISSING:
          try:
            del self.test_result.time_getter
          except AttributeError:
            pass
        else:
          self.test_result.time_getter = saved_getter

  def addSuccess(self, test: unittest.TestCase):
    self.actions.append(lambda: self.test_result.addSuccess(test))

  def addSkip(self, test: unittest.TestCase, reason: str):
    self.actions.append(lambda: self.test_result.addSkip(test, reason))

  def addError(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addError(test, err))

  def addFailure(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addFailure(test, err))

  def addExpectedFailure(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addExpectedFailure(test, err))

  def addUnexpectedSuccess(self, test: unittest.TestCase):
    # Without this method, unittest falls back to reporting an unexpected
    # success as a failure (with a RuntimeWarning) instead of recording it.
    self.actions.append(lambda: self.test_result.addUnexpectedSuccess(test))

  def addDuration(self, test: unittest.TestCase, elapsed):
    self.actions.append(lambda: self.test_result.addDuration(test, elapsed))


class JaxTestSuite(unittest.TestSuite):
  """Runs tests in parallel using threads if TEST_NUM_THREADS is > 0.

  With TEST_NUM_THREADS <= 0 this behaves like a plain ``unittest.TestSuite``.

  Caution: when thread parallelism is enabled, each test case is submitted
  individually to a worker thread. As a result, this suite:
  - does not run setUpClass, tearDownClass, setUpModule or tearDownModule
    methods;
  - ignores the ``debug`` argument;
  - does not stop early on ``result.shouldStop``.
  """

  def __init__(self, suite: unittest.TestSuite):
    # Copy the top-level tests (nested suites are kept as-is) into this suite.
    super().__init__(list(suite))

  def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult:
    num_threads = TEST_NUM_THREADS.value
    if num_threads <= 0:
      # Forward `debug`: TestSuite.debug() calls run(..., True) and relies on
      # it to let test exceptions propagate instead of recording them.
      return super().run(result, debug)

    test_warning_util.install_threadsafe_warning_handlers()

    # Shared by every ThreadSafeTestResult to serialize writes to `result`.
    lock = threading.Lock()
    futures = []

    with ThreadPoolExecutor(num_threads) as executor:

      def run_test(test):
        """Recursively submits every TestCase found in `test`."""
        if isinstance(test, unittest.TestSuite):
          for subtest in test:
            run_test(subtest)
        else:
          test_result = ThreadSafeTestResult(lock, result)
          futures.append(executor.submit(_run_one_test, test, test_result))

      run_test(self)
      # Re-raise the first exception escaping the harness itself; ordinary
      # test failures are recorded on `result` by TestCase.run, not raised.
      for future in futures:
        future.result()

    return result


class JaxTestLoader(absltest.TestLoader):
  """absltest loader that filters test methods by regular expression.

  Each candidate method is identified as ``"<TestCaseClassName>.<method>"``.
  ``_TEST_TARGETS`` (when non-empty) keeps only names it matches via
  ``re.search``. ``_EXCLUDE_TEST_TARGETS`` (when non-empty) then drops names it
  matches. Loaded tests are grouped into ``JaxTestSuite`` instances.
  """
  suiteClass = JaxTestSuite

  def getTestCaseNames(self, testCaseClass):
    prefix = testCaseClass.__name__ + "."
    selected = super().getTestCaseNames(testCaseClass)
    # Apply the include filter first, then the exclude filter. For each flag,
    # `keep_matches` says whether a regex hit keeps (True) or drops (False)
    # the name.
    for flag, keep_matches in ((_TEST_TARGETS, True),
                               (_EXCLUDE_TEST_TARGETS, False)):
      regex = flag.value
      if not regex:
        continue
      compiled = re.compile(regex)
      selected = [name for name in selected
                  if bool(compiled.search(prefix + name)) == keep_matches]
    return selected
