#!/usr/bin/ruby

#
# This script runs the unit tests in ../tests/white-box/*.rb (relative to this file).
#
# Tests for the platform differ from traditional ruby unit tests in a few ways:
#
# (1) at the end of every test function, you should call 'pass()'
# (2) you can specify test dependencies by calling depends_on("TestFirst") in the test class definition.
# (3) test functions are always run in alphabetical order.
# (4) any failure or error will stop the testing unless --continue is specified.
#

require 'minitest/unit'
require 'yaml'
require 'tsort'
require 'net/http'

##
## EXIT CODES
##

# Process exit statuses, keyed by the symbolic outcome names accepted by bail().
EXIT_CODES = {
  success: 0,
  warning: 1,
  failure: 2,
  error:   3
}

# Terminates the process, printing +msg+ to stdout first when one is given.
# +code+ is either a Symbol key of EXIT_CODES or a raw exit status.
def bail(code, msg=nil)
  puts msg if msg
  status = code.is_a?(Symbol) ? EXIT_CODES[code] : code
  exit(status)
end

##
## EXCEPTIONS
##

# this class is raised if a test file wants to be skipped entirely.
# Derives directly from Exception, so a bare `rescue` (StandardError) does not catch it.
class SkipTest < Exception; end

# raised to halt the run when a test errors and --continue was not given
# Derives directly from Exception, so a bare `rescue` (StandardError) does not catch it.
class TestError < Exception; end

# raised to halt the run when a test fails and --continue was not given
# Derives directly from Exception, so a bare `rescue` (StandardError) does not catch it.
class TestFailure < Exception; end

##
## CUSTOM UNIT TEST CLASS
##

#
# Our custom unit test class. All tests should be subclasses of this.
#
class LeapTest < MiniTest::Unit::TestCase
  # Raised by #pass to signal an explicit success.
  class Pass < MiniTest::Assertion; end

  def initialize(name)
    super
    # According to the original author, referencing #io here suppresses the
    # progress dots ("marching ants") the framework would otherwise print.
    io
  end

  #
  # Test class dependencies
  #
  # Declares that this test class must run after the named test classes.
  # May be called several times; the names accumulate. Returns the full list.
  def self.depends_on(*class_names)
    @dependencies = (@dependencies || []) + class_names
  end
  # Names declared through depends_on, or an empty array when none were declared.
  def self.dependencies
    return [] unless @dependencies
    @dependencies
  end

  #
  # returns all the test classes, sorted in dependency order.
  #
  def self.test_classes
    # `klass <= self` holds for this class itself and for every descendant,
    # i.e. exactly the classes whose ancestor chain contains this class.
    candidates = ObjectSpace.each_object(Class).select { |klass| klass <= self }
    TestDependencyGraph.new(candidates).sorted
  end

  # Sorted names of all instance methods whose name starts with "test_".
  def self.tests
    instance_methods.select { |method_name| method_name =~ /^test_/ }.sort
  end

  #
  # The default pass just does an `assert true`. In our case, we want to make the passes more explicit.
  #
  # Ends the current test method by raising LeapTest::Pass.
  def pass
    raise LeapTest::Pass.new
  end

  #
  # the default fail() is part of the kernel and it just throws a runtime exception. for tests,
  # we want the same behavior as assert(false)
  #
  # Shadows Kernel#fail so that it reports an assertion failure with +msg+.
  def fail(msg=nil)
    assert false, msg
  end

  # Reports a warning for the calling test method without ending the test.
  # Multiple message parts are joined with newlines.
  def warn(*msg)
    # The first caller frame names the test method, e.g. "...:in `block in test_foo'"
    # on Rubies that quote frames that way; strip the decoration around the name.
    calling_frame = caller[0]
    method_name = calling_frame.split('`')[-1].gsub(/(block in |')/,'')
    text = msg.join("\n")
    MiniTest::Unit.runner.warn(self.class, method_name, text)
  end

  # Always runs test methods within a test class in alphanumeric order
  #
  def self.test_order
    # :alpha selects alphabetical ordering of the test methods.
    :alpha
  end

  #
  # attempts a http GET on the url, yields |body, response, error|
  #
  # Performs an HTTP GET and yields exactly once with |body, response, error|:
  #
  # * success (2xx):     body string, response, nil
  # * non-success reply: nil, response, nil
  # * request failure:   nil, nil, the rescued StandardError
  #
  # +params+, when given, is form-encoded into the query string.
  # Returns whatever the block returns.
  #
  # Only the request itself is guarded by the rescue. An exception raised
  # inside the caller's block propagates unchanged, so the block is never
  # run a second time with a misleading "request error".
  def get(url, params=nil)
    body = response = error = nil
    begin
      uri = URI(url)
      uri.query = URI.encode_www_form(params) if params
      response = Net::HTTP.get_response(uri)
      body = response.body if response.is_a?(Net::HTTPSuccess)
    rescue => exc
      body = response = nil
      error = exc
    end
    yield body, response, error
  end

  # Asserts that a GET of +url+ succeeds. On success the body is yielded to
  # the optional block; otherwise the test fails with a message describing
  # either the unexpected status code or the request error.
  # options[:error_msg], when present, is appended to the failure message.
  def assert_get(url, params=nil, options=nil)
    options ||= {}
    get(url, params) do |body, response, error|
      if body
        block_given? ? yield(body) : nil
      else
        headline =
          if response
            "Expected a 200 status code from #{url}, but got #{response.code} instead."
          else
            "Expected a response from #{url}, but got \"#{error}\" instead."
          end
        fail [headline, options[:error_msg]].compact.join("\n")
      end
    end
  end

  #
  # test if a socket can be connected to
  #

  #
  # tcp connection helper with timeout
  #
  # Opens a TCP connection to host:port without blocking longer than
  # +timeout+ seconds, and returns the connected Socket. The caller is
  # responsible for closing the returned socket.
  #
  # Raises RuntimeError("Connection timeout") when the connection does not
  # complete in time. Resolution and connection errors propagate. On any
  # failure the socket is closed before the exception leaves this method,
  # so failed attempts do not leak file descriptors.
  def try_tcp_connect(host, port, timeout = 5)
    addr     = Socket.getaddrinfo(host, nil)
    sockaddr = Socket.pack_sockaddr_in(port, addr[0][3])
    socket   = Socket.new(Socket.const_get(addr[0][0]), Socket::SOCK_STREAM, 0)
    connected = false
    begin
      socket.setsockopt(Socket::IPPROTO_TCP, Socket::TCP_NODELAY, 1)
      begin
        socket.connect_nonblock(sockaddr)
      rescue IO::WaitReadable, IO::WaitWritable => wait
        # Wait for readiness in the direction the exception asks for.
        readers, writers = wait.is_a?(IO::WaitReadable) ? [[socket], nil] : [nil, [socket]]
        raise "Connection timeout" if IO.select(readers, writers, nil, timeout).nil?
        begin
          # A second connect call reports the outcome of the pending connect.
          socket.connect_nonblock(sockaddr)
        rescue Errno::EISCONN
          # Some platforms report "already connected" here; that is success.
        end
      end
      connected = true
    ensure
      socket.close unless connected
    end
    socket
  end

  # Writes a single NUL byte to +socket+ without blocking. When the socket is
  # not ready, waits up to +timeout+ seconds for it and tries again.
  # Raises RuntimeError("Write timeout") if the wait expires.
  def try_tcp_write(socket, timeout = 5)
    socket.write_nonblock("\0")
  rescue IO::WaitReadable
    raise "Write timeout" unless IO.select([socket], nil, nil, timeout)
    retry
  rescue IO::WaitWritable
    raise "Write timeout" unless IO.select(nil, [socket], nil, timeout)
    retry
  end

  # Reads a single byte from +socket+ without blocking. When the socket is
  # not ready, waits up to +timeout+ seconds for it and tries again.
  # Raises RuntimeError("Read timeout") if the wait expires.
  def try_tcp_read(socket, timeout = 5)
    socket.read_nonblock(1)
  rescue IO::WaitReadable
    raise "Read timeout" unless IO.select([socket], nil, nil, timeout)
    retry
  rescue IO::WaitWritable
    raise "Read timeout" unless IO.select(nil, [socket], nil, timeout)
    retry
  end

  # Asserts that a TCP connection to host:port can be opened within one
  # second. The socket is always closed again before returning.
  # +msg+, when given, is appended to the failure message as extra context.
  def assert_tcp_socket(host, port, msg=nil)
    socket = nil
    begin
      socket = try_tcp_connect(host, port, 1)
      #try_tcp_write(socket,1)
      #try_tcp_read(socket,1)
    rescue StandardError => exc
      fail ["Failed to open socket #{host}:#{port}", exc, msg].compact.join("\n")
    ensure
      socket.close if socket
    end
  end

  #
  # Matches the regexp in the file, and returns the first matched string (or fails if no match).
  #
  # Reads +filename+ and returns the first capture group of the first match
  # of +regexp+. Fails the test when the pattern does not match.
  def file_match(filename, regexp)
    found = File.read(filename).match(regexp)
    return found.captures.first if found
    fail "Regexp #{regexp.inspect} not found in file #{filename.inspect}."
  end

  #
  # Matches the regexp in the file, and returns array of matched strings (or fails if no match).
  #
  # Reads +filename+ and returns every capture group of the first match of
  # +regexp+ as an array. Fails the test when the pattern does not match.
  def file_matches(filename, regexp)
    found = File.read(filename).match(regexp)
    return found.captures if found
    fail "Regexp #{regexp.inspect} not found in file #{filename.inspect}."
  end

  #
  # checks to make sure the given property path exists in $node (e.g. hiera.yaml)
  # and returns the value
  #
  # Walks the dot-separated +property+ path through $node one segment at a
  # time and returns the value found at the end. Fails the test as soon as a
  # segment resolves to nil.
  def assert_property(property)
    property.split('.').inject($node) do |current, segment|
      value = current[segment]
      fail "Required node property `#{property}` is missing." if value.nil?
      value
    end
  end

  #
  # works like pgrep command line
  # return an array of hashes like so [{:pid => "1234", :process => "ls"}]
  #
  # Runs the pgrep command for +match+ and returns one hash per output line,
  # e.g. [{:pid => "1234", :process => "ls"}]. Lines that show the pgrep
  # invocation itself are dropped.
  def pgrep(match)
    listing = %x(pgrep --full --list-name '#{match}')
    listing.each_line.each_with_object([]) do |line, found|
      pid = line.split(' ')[0]
      process = line.gsub(/(#{pid} |\n)/, '')
      next if process =~ /pgrep --full --list-name/
      found << {:pid => pid, :process => process}
    end
  end
end

# Asserts that pgrep finds at least one process matching +process+.
def assert_running(process)
  matches = pgrep(process)
  assert !matches.empty?, "No running process for #{process}"
end

#
# runs the specified command, failing on a non-zero exit status.
#
# Executes +command+ and fails the test, quoting the command's stdout,
# when it does not exit with status 0.
def assert_run(command)
  output = %x(#{command})
  return if $?.exitstatus == 0
  fail "Error running `#{command}`:\n#{output}"
end

#
# Custom test runner in order to modify the output.
#
class LeapRunner < MiniTest::Unit

  attr_accessor :passes, :warnings

  # Starts the pass and warning counters at zero before the parent class
  # performs its own setup.
  def initialize
    @passes = @warnings = 0
    super
  end

  #
  # call stack:
  #   MiniTest::Unit.new.run
  #     MiniTest::Unit.runner
  #       LeapRunner#_run
  #
  # Runs either the pinned test class (optionally filtered down to the pinned
  # method) or every registered suite in dependency order, prints the status
  # summary and returns the exit-code symbol. An interrupt, or a halt raised
  # on failure/error, ends the process through bail().
  def _run args = []
    suites =
      if $pinned_test_class
        options.merge!(:filter => $pinned_test_method.to_s) if $pinned_test_method
        [$pinned_test_class]
      else
        TestDependencyGraph.new(LeapTest.send("test_suites")).sorted
      end
    output.sync = true
    results = _run_suites(suites, :test)
    # Each result is a (test count, assertion count) pair.
    @test_count      = results.inject(0) { |total, (tests, _)| total + tests }
    @assertion_count = results.inject(0) { |total, (_, assertions)| total + assertions }
    status
    return exit_code()
  rescue Interrupt
    bail :error, 'Tests halted on interrupt.'
  rescue TestFailure
    bail :failure, 'Tests halted on failure (because of --no-continue).'
  rescue TestError
    bail :error, 'Tests halted on error (because of --no-continue).'
  end

  #
  # override puke to change what prints out.
  #
  # Records the outcome +e+ of one test method and prints a report line.
  # Skip and Pass are checked before the generic Assertion, since the more
  # specific branch has to win. Raises TestFailure / TestError when
  # $halt_on_failure is set.
  def puke(klass, meth, e)
    if e.is_a?(MiniTest::Skip)
      @skips += 1
      report_line("SKIP", klass, meth, e, e.message)
    elsif e.is_a?(LeapTest::Pass)
      @passes += 1
      report_line("PASS", klass, meth)
    elsif e.is_a?(MiniTest::Assertion)
      @failures += 1
      report_line("FAIL", klass, meth, e, e.message)
      raise TestFailure.new if $halt_on_failure
    else
      @errors += 1
      trace = MiniTest::filter_backtrace(e.backtrace).join "\n"
      report_line("ERROR", klass, meth, e, "#{e.class}: #{e.message}\n#{trace}")
      raise TestError.new if $halt_on_failure
    end
    "" # disable the marching ants
  end

  #
  # override default status summary
  #
  # Prints the one-line result summary, in the :human output format only.
  # +io+ is the stream to write to and defaults to the runner's output.
  # (Previously the argument was accepted but ignored.)
  def status(io = self.output)
    return unless $output_format == :human
    summary = "%d tests: %d passes, %d skips, %d warnings, %d failures, %d errors"
    io.puts summary % [test_count, passes, skips, warnings, failures, errors]
  end

  #
  # return an appropriate exit_code symbol
  #
  # Picks the most severe outcome seen during the run:
  # errors, then failures, then warnings, otherwise success.
  def exit_code
    return :error   if @errors > 0
    return :failure if @failures > 0
    return :warning if @warnings > 0
    :success
  end

  #
  # prints a report line for a PASS, SKIP, WARN, FAIL, or ERROR result
  #
  # Prints one result line for a test method to the runner's output.
  #
  # prefix  - result label such as "PASS", "SKIP", "FAIL", "ERROR" or "WARN"
  # klass   - the test class
  # meth    - the test method name
  # e       - optional exception; its source location is added to the line
  # message - optional detail text
  #
  # In :human format the message is indented below the headline; in :checkmk
  # format it is flattened onto a single line after the status code.
  def report_line(prefix, klass, meth, e=nil, message=nil)
    msg_txt = nil
    if message
      # Mask every password embedded in an http:// or https:// URL, not just
      # the first http:// one, so credentials never reach the logs.
      message = message.gsub(/(https?):\/\/([a-z_]+):([a-zA-Z0-9_]+)@/, "\\1://\\2:password@")
      if $output_format == :human
        indent = "\n  "
        msg_txt = indent + message.split("\n").join(indent)
      else
        msg_txt = message.gsub("\n", ' ')
      end
    end

    if $output_format == :human
      if e && msg_txt
        output.puts "#{prefix}: #{readable(klass.name)} > #{readable(meth)} [#{File.basename(location(e))}]:#{msg_txt}"
      elsif msg_txt
        output.puts "#{prefix}: #{readable(klass.name)} > #{readable(meth)}:#{msg_txt}"
      else
        output.puts "#{prefix}: #{readable(klass.name)} > #{readable(meth)}"
      end
                      # I don't understand at all why, but adding a very tiny sleep here will
      sleep(0.0001)   # keep lines from being joined together by the logger. output.flush doesn't.
    elsif $output_format == :checkmk
      code = CHECKMK_CODES[prefix]
      msg_txt ||= "Success" if prefix == "PASS"
      if e && msg_txt
        output.puts "#{code} #{klass.name}/#{machine_readable(meth)} - [#{File.basename(location(e))}]:#{msg_txt}"
      elsif msg_txt
        output.puts "#{code} #{klass.name}/#{machine_readable(meth)} - #{msg_txt}"
      else
        output.puts "#{code} #{klass.name}/#{machine_readable(meth)} - no message"
      end
    end
  end

  #
  # a new function used by TestCase to report warnings.
  #
  # Counts one warning and prints it as a "WARN" report line. No exception
  # object is passed along, so the line carries no source location.
  def warn(klass, method_name, msg)
    @warnings += 1
    report_line("WARN", klass, method_name, nil, msg)
  end

  private

  # Status codes for the checkmk output format, keyed by report_line prefix.
  # "WARN" has to be listed because #warn reports with that prefix; without
  # an entry the status field of a warning line would be printed empty.
  CHECKMK_CODES = {"PASS" => 0, "SKIP" => 1, "WARN" => 1, "FAIL" => 2, "ERROR" => 3}

  #
  # Converts a snake_case name to something more pleasant for humans to read:
  # underscores become spaces and a leading "test" (plus sequence number) is dropped.
  #
  def readable(str)
    spaced = str.tr('_', ' ')
    # Drop a leading "test " and an optional run of digits that follows it.
    spaced.sub(/^test (\d* )?/i, '')
  end

  # Strips a leading "test_" and an optional numeric sequence prefix
  # ("test_01_foo" becomes "foo"); underscores are kept.
  def machine_readable(str)
    str.sub(/^test_(\d+_)?/i) { '' }
  end

end

##
## Dependency resolution
## Use a topological sort to manage test dependencies
##

#
# Orders test classes so that every class comes after the classes it
# depends on. Nodes are class names; edges come from each class's
# `dependencies` list.
#
class TestDependencyGraph
  include TSort

  # test_classes - classes responding to `name` and `dependencies`.
  def initialize(test_classes)
    @classes = {}       # class name => class object, so #sorted can return
                        # the objects it was given without a constant lookup.
    @dependencies = {}  # each key is a test class name, and the values
                        # are arrays of test class names that the key depends on.
    test_classes.each do |test_class|
      @classes[test_class.name] = test_class
      # Dependencies may be declared as strings, symbols or classes;
      # normalize them to names so they match the keys.
      @dependencies[test_class.name] = test_class.dependencies.map { |dep| dep.to_s }
    end
  end

  # TSort hook: yields every node (class name).
  def tsort_each_node(&block)
    @dependencies.each_key(&block)
  end

  # TSort hook: yields the names +test_class_name+ depends on. A name that
  # was never registered is a bad dependency and aborts the run.
  def tsort_each_child(test_class_name, &block)
    children = @dependencies[test_class_name]
    if children
      children.each(&block)
    else
      puts "ERROR: bad dependency, no such class `#{test_class_name}`"
      bail :error
    end
  end

  # Returns the test classes in dependency order (dependencies first).
  def sorted
    tsort.map { |class_name| @classes[class_name] }
  end
end

##
## COMMAND LINE ACTIONS
##

# Reports a fatal problem with the named test in the current output format
# and exits with the :error status.
def die(test, msg)
  case $output_format
  when :human   then puts "ERROR in test `#{test}`: #{msg}"
  when :checkmk then puts "3 #{test} - #{msg}"
  end
  bail :error
end

# Prints the command line usage summary and exits with status 0.
def print_help
  puts "USAGE: run_tests [OPTIONS]",
       "  --continue       Don't halt on an error, but continue to the next test.",
       "  --checkmk        Print test results in checkmk format (must come before --test).",
       "  --test TEST      Run only the test with name TEST.",
       "  --list-tests     Prints the names of all available tests and exit.",
       "  --retry COUNT    If the tests don't pass, retry COUNT additional times (default is zero)",
       "  --wait SECONDS   Wait for SECONDS between retries (default is 5)"
  exit(0)
end

# Prints every available test as CLASS/NAME (without the "test_" and numeric
# prefix), classes in dependency order, then exits with status 0.
def list_tests
  LeapTest.test_classes.each do |test_class|
    short_names = test_class.tests.map { |test| test.to_s.sub(/^test_(\d+_)?/, '') }
    short_names.each { |short_name| puts test_class.name + "/" + short_name }
  end
  exit(0)
end

# Restricts the run to a single test class, or to a single test method.
#
# name - "CLASS" or "CLASS/TEST", where TEST is the method name without its
#        "test_" and optional numeric prefix.
#
# Sets $pinned_test_class (and $pinned_test_method when TEST is given).
# Reports the problem through die() when the name is missing or does not
# match anything, instead of crashing on a nil name.
def pin_test_name(name)
  if name.nil? || name.empty?
    die '--test', 'expected a test name of the form CLASS or CLASS/TEST'
  end
  test_class, test_name = name.split('/')
  $pinned_test_class = LeapTest.test_classes.detect{|c| c.name == test_class}
  unless $pinned_test_class
    die name, "there is no test class `#{test_class}`"
  end
  if test_name
    # \A and \z anchor the whole method name, not just a line within it.
    pattern = /\Atest_(\d+_)?#{Regexp.escape(test_name)}\z/
    $pinned_test_method = $pinned_test_class.tests.detect{|m| m.to_s =~ pattern}
    unless $pinned_test_method
      die name, "there is no test `#{test_name}` in class `#{test_class}`"
    end
  end
end

#
# run the tests, multiple times if `--retry` and not all tests were successful.
#
# Runs the suite with a fresh LeapRunner, repeating up to $retry additional
# times (sleeping $wait seconds in between) until a run ends in :success.
# Exits the process with the status of the last run.
def run_tests
  attempts = $retry ? $retry + 1 : 1
  exit_code = nil
  attempts.times do |attempt|
    MiniTest::Unit.runner = LeapRunner.new
    exit_code = MiniTest::Unit.new.run
    break if !$retry || exit_code == :success
    sleep $wait if attempt != attempts - 1
  end
  bail exit_code
end

##
## MAIN
##

# Entry point: loads the node data, requires every white-box test file,
# parses the command line options into globals and runs the tests.
def main
  # load node data from hiera file; fall back to a dummy node when absent.
  # File.exist? is used because File.exists? was removed in Ruby 3.2.
  hiera_file = '/etc/leap/hiera.yaml'
  if File.exist?(hiera_file)
    $node = YAML.load_file(hiera_file)
  else
    $node = {"services" => [], "dummy" => true}
  end

  # load all test classes; a test file may raise SkipTest to opt out entirely
  this_file = File.symlink?(__FILE__) ? File.readlink(__FILE__) : __FILE__
  Dir[File.expand_path('../../tests/white-box/*.rb', this_file)].each do |test_file|
    begin
      require test_file
    rescue SkipTest
    end
  end

  # parse command line options; options are consumed from the front of ARGV
  # until the first unrecognized argument.
  $halt_on_failure = true
  $output_format = :human
  $retry = false
  $wait = 5
  loop do
    case ARGV[0]
      when '--continue' then ARGV.shift; $halt_on_failure = false;
      when '--checkmk' then ARGV.shift; $output_format = :checkmk; $halt_on_failure = false
      when '--help' then print_help
      when '--test' then ARGV.shift; pin_test_name(ARGV.shift)
      when '--list-tests' then list_tests
      when '--retry' then ARGV.shift; $retry = ARGV.shift.to_i
      when '--wait' then ARGV.shift; $wait = ARGV.shift.to_i
      else break
    end
  end
  run_tests
end

# Run the script.
main