# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unittests for the git_trace2_event_log.py module."""

import json
import os
import socket
import tempfile
import threading
import unittest
from unittest import mock

import git_trace2_event_log
import platform_utils


def serverLoggingThread(socket_path, server_ready, received_traces):
    """Helper function to receive logs over a Unix domain socket.

    Listens on |socket_path|, accepts a single connection, reads from it until
    the peer closes the connection, and appends every received line to
    |received_traces|.

    Args:
        socket_path: path to a Unix domain socket on which to listen for traces
        server_ready: a threading.Condition used to signal to the caller that
            this thread is ready to accept connections
        received_traces: a list to which received traces will be appended
            (after decoding to a utf-8 string).
    """
    # Remove any stale socket file so that bind() does not fail.
    platform_utils.remove(socket_path, missing_ok=True)
    chunks = []
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
        sock.bind(socket_path)
        sock.listen(0)
        # Only signal readiness once the socket is actually listening.
        with server_ready:
            server_ready.notify()
        with sock.accept()[0] as conn:
            while True:
                recved = conn.recv(4096)
                # An empty read means the peer closed the connection.
                if not recved:
                    break
                chunks.append(recved)
    # Decode once at the end so that multi-byte characters split across
    # recv() calls are handled correctly.
    data = b"".join(chunks)
    received_traces.extend(data.decode("utf-8").splitlines())


class EventLogTestCase(unittest.TestCase):
    """TestCase for the EventLog module."""

    PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID"
    PARENT_SID_VALUE = "parent_sid"
    SELF_SID_REGEX = r"repo-\d+T\d+Z-.*"
    FULL_SID_REGEX = rf"^{PARENT_SID_VALUE}/{SELF_SID_REGEX}"

    def setUp(self):
        """Load the event_log module every time."""
        # By default we initialize with the expected case where
        # repo launches us (so GIT_TRACE2_PARENT_SID is set).
        env = {
            self.PARENT_SID_KEY: self.PARENT_SID_VALUE,
        }
        self._event_log_module = git_trace2_event_log.EventLog(env=env)
        self._log_data = None

    def verifyCommonKeys(
        self, log_entry, expected_event_name=None, full_sid=True
    ):
        """Helper function to verify common event log keys.

        Args:
            log_entry: a decoded event (dict) read back from the log.
            expected_event_name: if set, the value the "event" key must have.
            full_sid: if True, the "sid" key must match FULL_SID_REGEX (parent
                sid followed by our own sid); otherwise it only needs to match
                SELF_SID_REGEX.
        """
        self.assertIn("event", log_entry)
        self.assertIn("sid", log_entry)
        self.assertIn("thread", log_entry)
        self.assertIn("time", log_entry)

        # Do basic data format validation.
        if expected_event_name:
            self.assertEqual(expected_event_name, log_entry["event"])
        if full_sid:
            self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX)
        else:
            self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX)
        self.assertRegex(
            log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+\+00:00$"
        )

    def readLog(self, log_path):
        """Helper function to read log data into a list.

        Each line of the file at |log_path| is parsed as one JSON document.
        """
        with open(log_path, mode="rb") as f:
            return [json.loads(line) for line in f]

    def remove_prefix(self, s, prefix):
        """Return |s| with a leading |prefix| removed.

        If |s| does not start with |prefix|, |s| is returned unchanged.
        """
        if s.startswith(prefix):
            return s[len(prefix) :]
        return s

    def test_initial_state_with_parent_sid(self):
        """Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent."""
        self.assertRegex(self._event_log_module.full_sid, self.FULL_SID_REGEX)

    def test_initial_state_no_parent_sid(self):
        """Test initial state when 'GIT_TRACE2_PARENT_SID' is not set."""
        # Setup an empty environment dict (no parent sid).
        self._event_log_module = git_trace2_event_log.EventLog(env={})
        self.assertRegex(self._event_log_module.full_sid, self.SELF_SID_REGEX)

    def test_version_event(self):
        """Test 'version' event data is valid.

        Verify that the 'version' event is written even when no other
        events are added.

        Expected event log:
            1. version event
        """
        with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
            log_path = self._event_log_module.Write(path=tempdir)
            self._log_data = self.readLog(log_path)

        # A log with no added events should only have the version entry.
        self.assertEqual(len(self._log_data), 1)
        version_event = self._log_data[0]
        self.verifyCommonKeys(version_event, expected_event_name="version")
        # Check for 'version' event specific fields.
        self.assertIn("evt", version_event)
        self.assertIn("exe", version_event)
        # Verify "evt" version field is a string.
        self.assertIsInstance(version_event["evt"], str)

    def test_start_event(self):
        """Test and validate 'start' event data is valid.

        Expected event log:
            1. version event
            2. start event
        """
        self._event_log_module.StartEvent()
        with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
            log_path = self._event_log_module.Write(path=tempdir)
            self._log_data = self.readLog(log_path)

        self.assertEqual(len(self._log_data), 2)
        start_event = self._log_data[1]
        self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
        self.verifyCommonKeys(start_event, expected_event_name="start")
        # Check for 'start' event specific fields.
        self.assertIn("argv", start_event)
        self.assertIsInstance(start_event["argv"], list)

    def test_exit_event_result_none(self):
        """Test 'exit' event data is valid when result is None.

        We expect None result to be converted to 0 in the exit event data.

        Expected event log:
            1. version event
            2. exit event
        """
        self._event_log_module.ExitEvent(None)
        with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
            log_path = self._event_log_module.Write(path=tempdir)
            self._log_data = self.readLog(log_path)

        self.assertEqual(len(self._log_data), 2)
        exit_event = self._log_data[1]
        self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
        self.verifyCommonKeys(exit_event, expected_event_name="exit")
        # Check for 'exit' event specific fields.
        self.assertIn("code", exit_event)
        # 'None' result should convert to 0 (successful) return code.
        self.assertEqual(exit_event["code"], 0)

    def test_exit_event_result_integer(self):
        """Test 'exit' event data is valid when result is an integer.

        Expected event log:
            1. version event
            2. exit event
        """
        self._event_log_module.ExitEvent(2)
        with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
            log_path = self._event_log_module.Write(path=tempdir)
            self._log_data = self.readLog(log_path)

        self.assertEqual(len(self._log_data), 2)
        exit_event = self._log_data[1]
        self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
        self.verifyCommonKeys(exit_event, expected_event_name="exit")
        # Check for 'exit' event specific fields.
        self.assertIn("code", exit_event)
        self.assertEqual(exit_event["code"], 2)

    def test_command_event(self):
        """Test and validate 'command' event data is valid.

        Expected event log:
            1. version event
            2. command event
        """
        name = "repo"
        # NB: The comma matters; adjacent string literals are concatenated
        # into a single element otherwise.
        subcommands = ["init", "this"]
        self._event_log_module.CommandEvent(
            name=name, subcommands=subcommands
        )
        with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
            log_path = self._event_log_module.Write(path=tempdir)
            self._log_data = self.readLog(log_path)

        self.assertEqual(len(self._log_data), 2)
        command_event = self._log_data[1]
        self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
        self.verifyCommonKeys(command_event, expected_event_name="command")
        # Check for 'command' event specific fields.
        self.assertIn("name", command_event)
        self.assertIn("subcommands", command_event)
        self.assertEqual(command_event["name"], name)
        self.assertEqual(command_event["subcommands"], subcommands)

    def test_def_params_event_repo_config(self):
        """Test 'def_params' event data outputs only repo config keys.

        Expected event log:
            1. version event
            2. def_param event (one per "repo." config key, two in total)
        """
        config = {
            "git.foo": "bar",
            "repo.partialclone": "true",
            "repo.partialclonefilter": "blob:none",
        }
        self._event_log_module.DefParamRepoEvents(config)

        with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
            log_path = self._event_log_module.Write(path=tempdir)
            self._log_data = self.readLog(log_path)

        self.assertEqual(len(self._log_data), 3)
        def_param_events = self._log_data[1:]
        self.verifyCommonKeys(self._log_data[0], expected_event_name="version")

        for event in def_param_events:
            self.verifyCommonKeys(event, expected_event_name="def_param")
            # Check for 'def_param' event specific fields.
            self.assertIn("param", event)
            self.assertIn("value", event)
            self.assertTrue(event["param"].startswith("repo."))

    def test_def_params_event_no_repo_config(self):
        """Test 'def_params' event data won't output non-repo config keys.

        Expected event log:
            1. version event
        """
        config = {
            "git.foo": "bar",
            "git.core.foo2": "baz",
        }
        self._event_log_module.DefParamRepoEvents(config)

        with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
            log_path = self._event_log_module.Write(path=tempdir)
            self._log_data = self.readLog(log_path)

        self.assertEqual(len(self._log_data), 1)
        self.verifyCommonKeys(self._log_data[0], expected_event_name="version")

    def test_data_event_config(self):
        """Test 'data' event data outputs all config keys.

        Expected event log:
            1. version event
            2. data event (one per config key, four in total)
        """
        config = {
            "git.foo": "bar",
            "repo.partialclone": "false",
            "repo.syncstate.superproject.hassuperprojecttag": "true",
            "repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"],
        }
        prefix_value = "prefix"
        self._event_log_module.LogDataConfigEvents(config, prefix_value)

        with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
            log_path = self._event_log_module.Write(path=tempdir)
            self._log_data = self.readLog(log_path)

        self.assertEqual(len(self._log_data), 5)
        data_events = self._log_data[1:]
        self.verifyCommonKeys(self._log_data[0], expected_event_name="version")

        for event in data_events:
            self.verifyCommonKeys(event)
            # Check for 'data' event specific fields.
            self.assertIn("key", event)
            self.assertIn("value", event)
            key = self.remove_prefix(event["key"], f"{prefix_value}/")
            value = event["value"]
            self.assertEqual(
                self._event_log_module.GetDataEventName(value), event["event"]
            )
            # Separate assertions so a failure reports the offending key/value.
            self.assertIn(key, config)
            self.assertEqual(value, config[key])

    def test_error_event(self):
        """Test and validate 'error' event data is valid.

        Expected event log:
            1. version event
            2. error event
        """
        msg = "invalid option: --cahced"
        fmt = "invalid option: %s"
        self._event_log_module.ErrorEvent(msg, fmt)
        with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
            log_path = self._event_log_module.Write(path=tempdir)
            self._log_data = self.readLog(log_path)

        self.assertEqual(len(self._log_data), 2)
        error_event = self._log_data[1]
        self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
        self.verifyCommonKeys(error_event, expected_event_name="error")
        # Check for 'error' event specific fields.
        self.assertIn("msg", error_event)
        self.assertIn("fmt", error_event)
        self.assertEqual(error_event["msg"], f"RepoErrorEvent:{msg}")
        self.assertEqual(error_event["fmt"], f"RepoErrorEvent:{fmt}")

    def test_write_with_filename(self):
        """Test Write() with a path to a file exits with None."""
        self.assertIsNone(self._event_log_module.Write(path="path/to/file"))

    def test_write_with_git_config(self):
        """Test Write() uses the git config path when 'git config' succeeds."""
        with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
            with mock.patch.object(
                self._event_log_module,
                "_GetEventTargetPath",
                return_value=tempdir,
            ):
                self.assertEqual(
                    os.path.dirname(self._event_log_module.Write()), tempdir
                )

    def test_write_no_git_config(self):
        """Test Write() with no git config variable present exits with None."""
        with mock.patch.object(
            self._event_log_module, "_GetEventTargetPath", return_value=None
        ):
            self.assertIsNone(self._event_log_module.Write())

    def test_write_non_string(self):
        """Test Write() with non-string type for |path| throws TypeError."""
        with self.assertRaises(TypeError):
            self._event_log_module.Write(path=1234)

    def test_write_socket(self):
        """Test Write() with Unix domain socket for |path|.

        Validates both the returned path and the traces received by the
        listening server.

        Expected received traces:
            1. version event
            2. start event
        """
        received_traces = []
        with tempfile.TemporaryDirectory(
            prefix="test_server_sockets"
        ) as tempdir:
            socket_path = os.path.join(tempdir, "server.sock")
            server_ready = threading.Condition()
            # Start "server" listening on Unix domain socket at socket_path.
            # Use a daemon thread so a server stuck in accept() (e.g. because
            # the client never connected) cannot keep the test process alive.
            server_thread = threading.Thread(
                target=serverLoggingThread,
                args=(socket_path, server_ready, received_traces),
                daemon=True,
            )
            # Hold the condition while starting the thread: the server cannot
            # notify() until we are inside wait(), so the wakeup is never lost.
            with server_ready:
                server_thread.start()
                server_is_ready = server_ready.wait(timeout=120)

            try:
                self.assertTrue(
                    server_is_ready,
                    "server thread never signalled that it was ready",
                )
                self._event_log_module.StartEvent()
                path = self._event_log_module.Write(
                    path=f"af_unix:{socket_path}"
                )
            finally:
                server_thread.join(timeout=5)

        self.assertEqual(path, f"af_unix:stream:{socket_path}")
        self.assertEqual(len(received_traces), 2)
        version_event = json.loads(received_traces[0])
        start_event = json.loads(received_traces[1])
        self.verifyCommonKeys(version_event, expected_event_name="version")
        self.verifyCommonKeys(start_event, expected_event_name="start")
        # Check for 'start' event specific fields.
        self.assertIn("argv", start_event)
        self.assertIsInstance(start_event["argv"], list)