"""
Test Context for BLE tests.
Provides environment for test execution.
"""
import asyncio
import logging
import time
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.shortcuts import PromptSession
from .ble_manager import BLEManager
logger = logging.getLogger(__name__)
# Decorator for test functions
[docs]
def ble_test(description=None):
"""
Decorate a BLE test function.
Args:
description: Description of the test (optional, will use function name if not provided)
Returns:
Decorated function
"""
def decorator(func):
# Set attributes on the function
func._is_ble_test = True
func._test_description = description
return func
# Handle case where decorator is used without arguments
if callable(description):
func = description
description = None
return decorator(func)
return decorator
# Decorator for test classes
[docs]
def ble_test_class(description=None):
"""
Decorate a BLE test class.
Args:
description: Description of the test class (optional, will use class name if not provided)
Returns:
Decorated class
"""
def decorator(cls):
# Set attributes on the class
cls._is_test_class = True
cls._test_description = description
return cls
# Handle case where decorator is used without arguments
if callable(description):
cls = description
description = None
return decorator(cls)
return decorator
[docs]
class TestStatus(Enum):
"""Enum for test execution status."""
PASS = "pass" # nosec B105
FAIL = "fail"
SKIP = "skip"
ERROR = "error"
RUNNING = "running"
def __str__(self):
"""Return the string representation of the test status."""
return self.value
[docs]
class TestException(Exception):
"""Base class for test exceptions."""
status = TestStatus.ERROR
def __init__(self, message=""):
"""Initialize the test exception."""
self.message = message
super().__init__(message)
[docs]
class TestFailure(TestException):
"""Exception raised when a test fails."""
status = TestStatus.FAIL
[docs]
class TestSkip(TestException):
"""Exception raised when a test is skipped."""
status = TestStatus.SKIP
[docs]
class NotificationResult(Enum):
"""Enum for notification evaluation results."""
IGNORE = "ignore" # Not what we're looking for, continue waiting
MATCH = "match" # Found what we were looking for, success
FAIL = "fail" # Found something that indicates a failure condition
def __str__(self):
"""Return the string representation of the notification result."""
return self.value
# Type alias for notification expected value
# Can be bytes for exact matching, a callable for custom evaluation, or None to match any notification
# The callable should return a boolean (pass or fail), a notification result enum, or a tuple of
# (NotificationResult, str)
# If the callable returns a NotificationResult of FAIL, the reason should be provided in the str
NotificationExpectedValue = Optional[
Union[
bytes,
Callable[[bytes], Union[bool, NotificationResult, Tuple[NotificationResult, str]]],
]
]
[docs]
class NotificationWaiter:
"""Helper class to wait for notifications."""
def __init__(self, characteristic_uuid: str, expected_value: NotificationExpectedValue = None):
"""Initialize the notification waiter."""
self.characteristic_uuid = characteristic_uuid
self.expected_value = expected_value
self.received_notifications = []
self.matching_notification = None # Will store the matching notification data
self.failure_reason = None # Will store failure message if applicable
self.complete_event = asyncio.Event()
[docs]
def check_notification(self, data: bytes) -> Tuple[bool, Optional[str]]:
"""
Check if a notification matches our criteria.
Args:
data: The notification data to check
Returns:
Tuple of (is_match, failure_reason)
- is_match: True if the notification matches expected criteria
- failure_reason: If the notification indicates a failure condition, the reason
"""
current_expected = self.expected_value
if current_expected is None:
# No expected value - any notification is a match
return True, None
if callable(current_expected):
# User provided a lambda/function to evaluate the notification
try:
result = current_expected(data)
if isinstance(result, tuple) and len(result) == 2 and isinstance(result[0], NotificationResult):
# Handle tuple return format: (NotificationResult, Optional[str])
notification_result, reason = result
if notification_result == NotificationResult.MATCH:
return True, None
elif notification_result == NotificationResult.FAIL:
return (
False,
reason or f"Notification evaluated as failure condition: {data.hex()}",
)
else: # IGNORE
return False, None
elif isinstance(result, NotificationResult):
# Handle direct NotificationResult enum
if result == NotificationResult.MATCH:
return True, None
elif result == NotificationResult.FAIL:
return (
False,
f"Notification evaluated as failure condition: {data.hex()}",
)
else: # IGNORE
return False, None
else:
# True = match, False = ignore (not a failure)
return bool(result), None
except Exception as e:
# If the function raises an exception, log it but don't fail
logger.error(f"Error in notification evaluation function: {e}")
return False, None
else:
# Direct comparison with expected bytes
return (
data == current_expected,
f"Notification {data.hex()} ≠ {current_expected.hex()} expected value",
)
[docs]
def on_notification(self, data) -> bool:
"""
Handle a notification.
Args:
data: The notification data
Returns:
True if the notification matches the expected value, False otherwise
"""
if self.complete_event.is_set():
return False
# Store all notifications we receive
self.received_notifications.append(data)
# Check immediately if this is the notification we're waiting for
is_match, failure_reason = self.check_notification(data)
if is_match:
logger.debug("Found matching notification in callback - setting event")
self.matching_notification = data
self.complete_event.set()
return True
elif failure_reason:
# We have a failure condition from the notification
logger.debug(f"Notification indicates failure: {failure_reason}")
self.failure_reason = failure_reason
self.complete_event.set()
return False
else:
logger.debug("Notification in callback didn't match criteria")
return False
[docs]
class NotificationSubscription:
"""A helper class to manage notification subscriptions and waiters."""
def __init__(self, characteristic_uuid: str, initial_waiter: NotificationWaiter = None):
"""Initialize the notification subscription."""
self.characteristic_uuid = characteristic_uuid
self.waiter = initial_waiter
self.collected_notifications = []
[docs]
def on_notification(self, data):
"""Handle a notification."""
self.collected_notifications.append(data)
logger.debug(
f"Notification callback received: {data.hex() if data else 'None'}, "
f"{len(self.collected_notifications)} notifications collected"
)
if self.waiter is None:
return
if self.waiter.on_notification(data):
self.collected_notifications.clear()
[docs]
def set_waiter(self, waiter: NotificationWaiter, process_collected_notifications: bool = True):
"""Set the waiter for the subscription."""
self.waiter = waiter
if process_collected_notifications:
logger.debug(f"Processing {len(self.collected_notifications)} collected notifications")
for i in range(len(self.collected_notifications)):
if self.waiter.on_notification(self.collected_notifications[i]):
# If we found a match, clear all notifications up to and including the current one
self.collected_notifications = self.collected_notifications[i + 1 :]
break
[docs]
def clear_waiter(self):
"""Clear the waiter for the subscription."""
self.waiter = None
[docs]
class TestContext:
"""
Context for test execution.
Provides access to the BLE device and helper methods for test operations.
"""
def __init__(self, ble_manager: BLEManager):
"""Initialize the test context."""
self.ble_manager = ble_manager
self.start_time = time.time()
self.test_results: Dict[str, Dict[str, Any]] = {}
self.current_test: Optional[str] = None
self.notification_subscriptions: Dict[str, NotificationSubscription] = {}
[docs]
def print_formatted_box(self, title: str, messages: List[str]) -> None:
"""
Print a formatted box with consistent alignment.
Args:
title: The title to display at the top of the box
messages: List of message lines to display in the box
"""
# Box width (including borders)
box_width = 80
content_width = box_width - 4 # Allow for borders and spaces
# Print top border
print("\n╔" + "═" * (box_width - 2) + "╗")
# Print title if provided
if title:
# Ensure title fits in box with proper padding
if len(title) > content_width:
title = title[: content_width - 3] + "..."
padding = " " * (content_width - len(title))
print(f"║ {title}{padding} ║")
# Print messages
for message in messages:
# Pre-process message to handle newlines properly
message_parts = message.split("\n")
for part in message_parts:
# Split long lines into multiple lines with word wrap
remaining = part
while remaining:
if len(remaining) <= content_width:
# Line fits, use it completely
line = remaining
remaining = ""
else:
# Try to break at word boundary
split_pos = remaining[:content_width].rfind(" ")
if split_pos <= 0: # No space found or at beginning, just split at max length
split_pos = content_width
line = remaining[:split_pos].rstrip()
remaining = remaining[split_pos:].lstrip() if split_pos < len(remaining) else ""
# Ensure padding for a uniform right edge
padding = " " * (content_width - len(line))
print(f"║ {line}{padding} ║")
# Print bottom border
print("╚" + "═" * (box_width - 2) + "╝")
[docs]
def print(self, message: str) -> None:
"""
Print a message directly to the console for user-facing output.
Use this for information that should always be visible to the user,
regardless of log level settings.
Args:
message: The message to display to the user
"""
# Print the message with ANSI codes intact
print(message)
# Strip ANSI escape codes for logging
import re
ansi_escape = re.compile(r"\033\[[0-9;]*[a-zA-Z]")
clean_message = ansi_escape.sub("", message)
# Also log the message at INFO level for record keeping
logger.info(clean_message)
# Store in test results with level information
if self.current_test:
self.test_results[self.current_test]["logs"].append(
{
"timestamp": time.time(),
"level": "USER", # Special level to mark user-facing output
"message": clean_message,
}
)
[docs]
def prompt_user(self, message: str) -> str:
"""
Display a prompt to the user and wait for input.
Args:
message: The message to display to the user
Returns:
User's input response
"""
# Use the formatted box function
self.print_formatted_box("USER ACTION REQUIRED", [message])
response = input("Enter your response and press Enter to continue: ")
logger.info(f"User response: {response}")
return response
[docs]
def start_test(self, test_name: str) -> None:
"""
Start a new test and record the start time.
Args:
test_name: Name of the test being started
"""
self.current_test = test_name
self.test_results[test_name] = {
"start_time": time.time(),
"status": TestStatus.RUNNING.value,
"duration": 0,
"logs": [],
}
logger.debug(f"Starting test: {test_name}")
[docs]
async def unsubscribe_all(self) -> None:
"""
Unsubscribe from all active notification subscriptions.
Call this at the end of a test to clean up resources.
"""
if not self.notification_subscriptions:
return
# Make a copy of the keys since we'll be modifying the dictionary
characteristics = list(self.notification_subscriptions.keys())
for characteristic_uuid in characteristics:
try:
logger.debug(f"Unsubscribing from {characteristic_uuid}")
await self.ble_manager.unsubscribe_from_characteristic(characteristic_uuid)
# Remove from subscriptions
self.notification_subscriptions.pop(characteristic_uuid, None)
logger.debug(f"Successfully unsubscribed from {characteristic_uuid}")
except Exception as e:
logger.error(f"Error unsubscribing from {characteristic_uuid}: {str(e)}")
logger.debug(f"Unsubscribed from all {len(characteristics)} active characteristics")
[docs]
async def cleanup_tasks(self):
"""
Clean up any remaining async tasks created during testing.
This should be called before program exit.
"""
# Unsubscribe from all notifications
await self.unsubscribe_all()
# Clear any remaining state
logger.debug("Cleanup tasks completed")
[docs]
def end_test(self, status: Union[TestStatus, str], message: str = "") -> Dict[str, Any]:
"""
End the current test and record results.
Args:
status: Test status (TestStatus enum or string value)
message: Optional message about test result
Returns:
Test result details
"""
if not self.current_test:
logger.warning("Attempted to end test but no test is currently running")
return {}
end_time = time.time()
test_name = self.current_test
# Convert string status to enum if needed
if isinstance(status, str):
try:
# Try to convert the string to enum
status = next(s for s in TestStatus if s.value == status)
except StopIteration:
logger.warning(f"Unknown test status '{status}', using as-is")
# Get the string value if it's an enum
status_value = status.value if isinstance(status, TestStatus) else status
# Update test results
self.test_results[test_name].update(
{
"end_time": end_time,
"duration": end_time - self.test_results[test_name]["start_time"],
"status": status_value,
"message": message,
}
)
# Define color codes for different statuses
status_display = {
TestStatus.PASS.value: "\033[92mPASSED ✓\033[0m", # Green
TestStatus.FAIL.value: "\033[91mFAILED ✗\033[0m", # Red
TestStatus.SKIP.value: "\033[93mSKIPPED -\033[0m", # Yellow
TestStatus.ERROR.value: "\033[93mERROR !\033[0m", # Yellow
}.get(status_value, status_value.upper())
# Print a simpler message instead of a formatted box
# Only use formatted boxes for things that need user attention
duration = self.test_results[test_name]["duration"]
print("") # Add space before result for visual separation
self.print(f"Test {test_name} {status_display} in {duration:.2f}s" + (f": {message}" if message else ""))
print("") # Add space after result
# Also log for record keeping
logger.info(f"Test {test_name} {status_value}{': ' + message if message else ''}")
logger.debug(f"Test duration: {self.test_results[test_name]['duration']:.2f} seconds")
# Reset current test
self.current_test = None
# Return the results
return self.test_results[test_name]
[docs]
def log(self, message: str, level: str = "info") -> None:
"""
Log a message within the current test context.
Args:
message: Message to log
level: Log level (debug, info, warning, error, critical)
"""
# Convert string level to logging level
log_level = getattr(logging, level.upper(), logging.INFO)
# Always store in test results with level information for later retrieval
if self.current_test:
self.test_results[self.current_test]["logs"].append(
{"timestamp": time.time(), "level": level.upper(), "message": message}
)
# Only display INFO and DEBUG logs in the console if the test fails
# WARNING, ERROR, CRITICAL logs are always displayed
if log_level >= logging.WARNING: # WARNING=30, ERROR=40, CRITICAL=50
# Always log warnings, errors, and critical messages
logger.log(log_level, message)
else:
# For INFO and DEBUG logs, we only store them but don't display during test execution
# They will be displayed in the results summary if the test fails
pass
[docs]
def debug(self, message: str) -> None:
"""Log a debug message within the current test context."""
self.log(message, level="debug")
[docs]
def info(self, message: str) -> None:
"""Log an info message within the current test context."""
self.log(message, level="info")
[docs]
def warning(self, message: str) -> None:
"""Log a warning message within the current test context."""
self.log(message, level="warning")
[docs]
def error(self, message: str) -> None:
"""Log an error message within the current test context."""
self.log(message, level="error")
[docs]
def critical(self, message: str) -> None:
"""Log a critical message within the current test context."""
self.log(message, level="critical")
[docs]
async def subscribe_to_characteristic(
self,
characteristic_uuid: str,
waiter: Optional[NotificationWaiter] = None,
process_collected_notifications: bool = True,
):
"""
Subscribe to a characteristic and create a waiter if provided.
Args:
characteristic_uuid: UUID of characteristic to subscribe to
waiter: Optional NotificationWaiter instance to use
process_collected_notifications: If True, process collected notifications
"""
# Only subscribe if not already subscribed
if characteristic_uuid not in self.notification_subscriptions:
try:
logger.debug(f"Subscribing to characteristic {characteristic_uuid}")
# Create the waiter first
sub = NotificationSubscription(characteristic_uuid, waiter)
self.notification_subscriptions[characteristic_uuid] = sub
# Now subscribe with on_notification
await self.ble_manager.subscribe_to_characteristic(characteristic_uuid, sub.on_notification)
logger.debug(f"Successfully subscribed to {characteristic_uuid}")
# Short delay to ensure subscription is active
await asyncio.sleep(0.5)
except Exception as e:
logger.error(f"Error subscribing to characteristic: {str(e)}")
# Remove the waiter if we failed to subscribe
if characteristic_uuid in self.notification_subscriptions:
del self.notification_subscriptions[characteristic_uuid]
raise RuntimeError(f"Failed to subscribe: {str(e)}")
else:
# Already subscribed - reuse the existing subscription
logger.debug(f"Using existing subscription to {characteristic_uuid}")
# Get existing subscription and update the waiter
sub = self.notification_subscriptions[characteristic_uuid]
if waiter:
sub.set_waiter(waiter, process_collected_notifications)
else:
sub.clear_waiter()
return sub
[docs]
async def create_notification_waiter(
self,
characteristic_uuid: str,
expected_value: NotificationExpectedValue = None,
process_collected_notifications: bool = True,
) -> NotificationWaiter:
"""
Create a notification waiter for a characteristic.
Args:
characteristic_uuid: UUID of characteristic to wait for notification
expected_value: If provided, validates the notification value. Can be:
- bytes: exact value to match
- callable: function that takes the notification data and returns a NotificationResult
Returns:
NotificationWaiter instance
"""
waiter = NotificationWaiter(characteristic_uuid, expected_value)
await self.subscribe_to_characteristic(characteristic_uuid, waiter, process_collected_notifications)
return waiter
[docs]
def handle_notification_waiter_result(self, waiter: NotificationWaiter, timeout: float) -> Dict[str, Any]:
"""
Handle the result of a notification waiter.
Args:
waiter: The notification waiter to check
timeout: The timeout value that was used (for error messages)
Returns:
Dictionary with notification details if successful:
'value': The notification value that matched the expected value (bytes)
'success': True if notification received and matched expected
'received_notifications': List of all notifications received (list of bytes)
Raises:
TestFailure: If a notification indicates a test failure
TimeoutError: If no notification was received within the timeout period
"""
if waiter.matching_notification:
logger.debug(
"Found matching notification: "
f"{waiter.matching_notification.hex() if waiter.matching_notification else 'None'}"
)
return {
"value": waiter.matching_notification,
"success": True,
"received_notifications": waiter.received_notifications,
}
elif waiter.failure_reason:
# We got a failure notification
logger.info(f"Test failed due to notification: {waiter.failure_reason}")
raise TestFailure(waiter.failure_reason)
elif waiter.received_notifications:
# We got notifications but none matched our expected value
logger.info(
f"Received {len(waiter.received_notifications)} notifications, but none matched the expected value"
)
for i, notif in enumerate(waiter.received_notifications):
logger.debug(f"Notification {i+1}: {notif.hex() if notif else 'None'}")
# Raise exception for non-matching notifications
raise TestFailure(
f"No matching notification received. Got: {', '.join(n.hex() for n in waiter.received_notifications)}"
)
else:
# Raise timeout error with user-friendly message
raise TimeoutError(f"No notification received within {timeout} seconds")
[docs]
async def wait_for_notification(
self,
characteristic_uuid: str,
timeout: float = 10.0,
expected_value: NotificationExpectedValue = None,
process_collected_notifications: bool = True,
) -> Dict[str, Any]:
"""
Wait for a notification from a characteristic without user interaction.
Args:
characteristic_uuid: UUID of characteristic to wait for notification
timeout: Maximum time to wait in seconds
expected_value: If provided, validates the notification value. Can be:
- bytes: exact value to match
- callable: function that takes the notification data and returns a NotificationResult
Returns:
Dictionary with notification details:
'value': The notification value
'success': True if notification received and matched expected
'received_notifications': List of all notifications received
Raises:
TimeoutError: If no notification is received within the timeout
TestFailure: If a notification is received but doesn't match expected criteria
"""
waiter = await self.create_notification_waiter(
characteristic_uuid, expected_value, process_collected_notifications
)
try:
# Create a task that will complete when a notification is received
notification_future = asyncio.create_task(waiter.complete_event.wait())
# Wait for notification or timeout
try:
await asyncio.wait_for(notification_future, timeout)
logger.debug("Notification received before timeout")
except asyncio.TimeoutError:
logger.info(f"Timed out waiting for notification after {timeout} seconds")
if notification_future.cancel():
logger.debug("Successfully cancelled notification future")
finally:
pass # We'll keep the subscription active for potential future notifications
return self.handle_notification_waiter_result(waiter, timeout)
[docs]
async def wait_for_notification_interactive(
self,
characteristic_uuid: str,
timeout: float = 10.0,
expected_value: NotificationExpectedValue = None,
) -> Dict[str, Any]:
"""
Wait for a notification from a characteristic with user interaction support.
This method will display a prompt to the user and wait for a notification.
The user can type 's' or 'skip' to skip the test, or 'f' or 'fail' to fail it.
If the user chooses to skip or fail, the appropriate TestSkip or TestFailure
exception will be raised automatically.
Args:
characteristic_uuid: UUID of characteristic to wait for notification
timeout: Maximum time to wait in seconds
expected_value: If provided, validates the notification value. Can be:
- bytes: exact value to match
- callable: function that takes the notification data and returns a NotificationResult
Returns:
Dictionary with notification details:
'value': The notification value
'success': True if notification received and matched expected
'received_notifications': List of all notifications received
Raises:
TestSkip: If the user chooses to skip the test
TestFailure: If the user chooses to fail the test
TimeoutError: If no notification is received within the timeout
"""
async def user_input_handler() -> Tuple[str, str]:
"""Handle user input during the waiting period."""
print("\nThe test will continue automatically when event is detected.")
print(
"If nothing happens, type 's' or 'skip' to skip, 'f' or 'fail' to fail the test, or 'd' for debug info."
)
session = PromptSession()
with patch_stdout():
while True:
user_input = None
try:
user_input = await session.prompt_async()
except (EOFError, KeyboardInterrupt):
# Handle Ctrl+D and Ctrl+C gracefully
user_input = "f" # Treat as "fail" to abort the test
except asyncio.CancelledError:
# Task cancelled - exit cleanly
logger.debug("User input task cancelled")
break
except Exception as e:
logger.error(f"Error in user input handler: {e}")
break
# Handle EOF or errors
if not user_input:
logger.info("Input stream closed or returned empty input")
break
user_input = user_input.strip().lower()
# Process based on user input
if user_input in ["s", "skip"]:
return ("skip", "User chose to skip the test")
elif user_input in ["f", "fail"]:
return ("fail", "User reported test failure")
elif user_input == "d":
# Debug - show received notifications
if characteristic_uuid not in self.notification_subscriptions:
print(f"No subscription to {characteristic_uuid}")
continue
sub = self.notification_subscriptions[characteristic_uuid]
if sub.waiter is None:
print(f"No waiter for {characteristic_uuid}")
continue
if len(sub.waiter.received_notifications) == 0:
print(f"No notifications received for {characteristic_uuid}")
continue
print(f"Received {len(sub.waiter.received_notifications)} notifications so far:")
for i, n in enumerate(sub.waiter.received_notifications):
print(f" Notification {i+1}: {n.hex() if n else 'None'}")
is_match, _ = sub.waiter.check_notification(n)
if is_match:
print(" --> This notification MATCHES the expected criteria")
else:
print(" --> Does NOT match expected criteria")
# Continue waiting - don't break the loop
else:
print("Invalid input. Type 's' to skip, 'f' to fail, or 'd' for debug info.")
waiter = await self.create_notification_waiter(characteristic_uuid, expected_value, False)
# Start user input handler
user_input_task = asyncio.create_task(user_input_handler())
# Create a task for monitoring the notification event
notification_task = asyncio.create_task(waiter.complete_event.wait())
try:
# Wait for the first of the tasks to complete or for timeout
done, pending = await asyncio.wait(
[notification_task, user_input_task],
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED,
)
# Determine which task completed first
if notification_task in done:
logger.info("Notification task completed first")
return self.handle_notification_waiter_result(waiter, timeout)
elif user_input_task in done:
# User input finished first
user_response, message = user_input_task.result()
logger.info(f"User input task completed first: {message}")
# Raise appropriate exception based on user input
if user_response == "skip":
raise TestSkip("User chose to skip test")
elif user_response == "fail":
raise TestFailure("User reported test failure")
else:
logger.info("Timeout occurred while waiting for notification or user input")
# Cancel any pending tasks
for task in pending:
task.cancel()
finally:
# Make sure all tasks are cancelled
for task in [notification_task, user_input_task]:
if not task.done():
task.cancel()
try:
await asyncio.wait_for(task, timeout=0.1)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
[docs]
def get_test_summary(self) -> Dict[str, Any]:
"""
Generate a summary of all test results.
Returns:
Dictionary with test summary statistics
"""
# Filter out tests that are still in 'running' status - these are duplicate entries
completed_results = {
name: result
for name, result in self.test_results.items()
if result.get("status") != TestStatus.RUNNING.value
}
total_tests = len(completed_results)
passed_tests = sum(1 for result in completed_results.values() if result["status"] == TestStatus.PASS.value)
failed_tests = sum(1 for result in completed_results.values() if result["status"] == TestStatus.FAIL.value)
total_duration = sum(result["duration"] for result in completed_results.values() if "duration" in result)
return {
"total_tests": total_tests,
"passed_tests": passed_tests,
"failed_tests": failed_tests,
"total_duration": total_duration,
"results": self.test_results, # Return all results for debugging, filtering happens in CLI
}