add router shutdown and task cancellation control

This commit is contained in:
Sam G. 2024-05-12 00:46:08 -07:00
parent 456fa7c84f
commit f9de5cfe4a
5 changed files with 104 additions and 35 deletions
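
In short: Router gains a should_exit flag, an _active_futures set, and a shutdown() method, so listeners and syncers can tear the pipeline down through one call instead of poking at the thread pool directly. A rough usage sketch of the new flow follows; the PathRouter constructor call and surrounding setup are assumptions, only the shutdown() semantics come from the diff.

# Rough usage sketch of the shutdown flow added in this commit; constructor
# arguments and setup are assumed, not shown in the diff.
from execlog.routers import PathRouter

router = PathRouter()            # hypothetical setup; args omitted
# ... callbacks get submitted to router.thread_pool as events arrive ...

router.shutdown()
# Per the diff, shutdown():
#   1. sets router.should_exit, so queued callbacks wrapped by
#      wrap_safe_callback() return early instead of running
#   2. cancels every future still tracked in router._active_futures
#   3. calls router.thread_pool.shutdown(wait=False)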

View File

@@ -401,6 +401,5 @@ class PathListener(Listener[FileEvent]):
         self.write.write(b"\x00")
         self.write.close()
-        if self.router.thread_pool is not None:
-            self.router.thread_pool.shutdown()
+        self.router.shutdown()

View File

@@ -9,16 +9,16 @@ import traceback
 import threading
 from pathlib import Path
 from typing import Any, Callable
-from colorama import Fore, Style
 from collections import defaultdict
+from colorama import Fore, Back, Style
 from functools import partial, update_wrapper
 from concurrent.futures import ThreadPoolExecutor, wait, as_completed
 from tqdm.auto import tqdm
-from execlog.util.generic import color_text
 from execlog.event import Event
 from execlog.listener import Listener
+from execlog.util.generic import color_text

 logger = logging.getLogger(__name__)
@@ -111,6 +111,10 @@ class Router[E: Event]:
         # track event history
         self.event_log = []

+        # shutdown flag, mostly for callbacks
+        self.should_exit = False
+        self._active_futures = set()
+
         self._thread_pool = None
         self._route_lock = threading.Lock()
@@ -229,6 +233,11 @@
         Note: this method is expected to return a future. Perform any event-based
         filtering before submitting a callback with this method.
         '''
+        # exit immediately if exit flag is set
+        # if self.should_exit:
+        #     return
+        callback = self.wrap_safe_callback(callback)
+
         if inspect.iscoroutinefunction(callback):
             if self.loop is None:
                 self.loop = asyncio.new_event_loop()
@@ -245,7 +254,8 @@
             future = self.thread_pool.submit(
                 callback, *args, **kwargs
             )
-            future.add_done_callback(handle_exception)
+            self._active_futures.add(future)
+            future.add_done_callback(self.general_task_done)

             return future
@@ -354,9 +364,10 @@
         future_results = []
         for future in as_completed(futures):
             try:
-                future_results.append(future.result())
+                if not future.cancelled():
+                    future_results.append(future.result())
             except Exception as e:
-                logger.warning(f"Router callback job failed with exception {e}")
+                logger.warning(f"Router callback job failed with exception \"{e}\"")

         return future_results
@@ -375,6 +386,22 @@
         '''
         self.running_events[event_idx].update(callbacks)

+    def wrap_safe_callback(self, callback: Callable):
+        '''
+        Check for shutdown flag and exit before running the callbacks.
+
+        Applies primarily to jobs enqueued by the ThreadPoolExecutor but not started when
+        an interrupt is received.
+        '''
+        def safe_callback(callback, *args, **kwargs):
+            if self.should_exit:
+                logger.debug('Exiting early from queued callback')
+                return
+
+            return callback(*args, **kwargs)
+
+        return partial(safe_callback, callback)
+
     def filter(self, event: E, pattern, **listen_kwargs) -> bool:
         '''
         Determine if a given event matches the provided pattern
@@ -435,8 +462,15 @@
         The check for results from the passed future allows us to know when in fact a
         valid frame has finished, and a resubmission may be on the table.
         '''
-        result = future.result()
-        if not result: return
+        result = None
+        if not future.cancelled():
+            result = future.result()
+        else:
+            return None
+
+        # result should be *something* if work was scheduled
+        if not result:
+            return None

         self.event_log.append((event, result))
         queued_callbacks = self.stop_event(event)
@@ -451,6 +485,29 @@
     def event_index(self, event):
         return event[:2]

+    def shutdown(self):
+        logger.info(color_text('Router shutdown received', Fore.BLACK, Back.RED))
+        self.should_exit = True
+
+        for future in tqdm(
+            list(self._active_futures),
+            desc=color_text('Cancelling active futures...', Fore.BLACK, Back.RED),
+            colour='red',
+        ):
+            future.cancel()
+
+        if self.thread_pool is not None:
+            self.thread_pool.shutdown(wait=False)
+
+    def general_task_done(self, future):
+        self._active_futures.remove(future)
+
+        try:
+            if not future.cancelled():
+                future.result()
+        except Exception as e:
+            logger.error(f"Exception occurred in threaded task: '{e}'")
+            #traceback.print_exc()

 class ChainRouter[E: Event](Router[E]):
     '''
@@ -545,14 +602,6 @@ class ChainRouter[E: Event](Router[E]):
         return listener


-def handle_exception(future):
-    try:
-        future.result()
-    except Exception as e:
-        print(f"Exception occurred: {e}")
-        traceback.print_exc()
-
-
 # RouterBuilder
 def route(router, route_group, **route_kwargs):
     def decorator(f):
@@ -666,7 +715,9 @@ class RouterBuilder(ChainRouter, metaclass=RouteRegistryMeta):
            # assumed no kwargs for passthrough
            if route_group == 'post':
                for method, _ in method_arg_list:
-                    router.add_post_callback(method)
+                    router.add_post_callback(
+                        update_wrapper(partial(method, self), method),
+                    )
                continue

            group_options = router_options.get(route_group)
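
Why both the cancellation loop in shutdown() and the should_exit check? ThreadPoolExecutor can only cancel futures that have not started; anything a worker has already picked up runs to completion, so wrap_safe_callback gives those jobs an early exit. Below is a standalone sketch of the same pattern, outside execlog; all names in it are illustrative.

# Standalone sketch of the queued-callback guard used by wrap_safe_callback;
# illustrative only, not part of the execlog API.
from concurrent.futures import ThreadPoolExecutor
from functools import partial

should_exit = False

def wrap_safe(callback):
    def safe_callback(cb, *args, **kwargs):
        if should_exit:               # flag is checked at call time by queued jobs
            return None
        return cb(*args, **kwargs)
    return partial(safe_callback, callback)

pool = ThreadPoolExecutor(max_workers=2)
futures = [pool.submit(wrap_safe(print), i) for i in range(100)]

# Shutdown: raise the flag first, then cancel whatever has not started yet.
should_exit = True
for f in futures:
    f.cancel()                        # no-op for futures already running
pool.shutdown(wait=False)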

View File

@@ -15,7 +15,6 @@ listeners in one place.
 make the Server definition more flexible.
 '''
 import re
-import signal
 import asyncio
 import logging
 import threading
@@ -75,6 +74,8 @@ class Server:
         self.server_text = ''
         self.server_args = {}
+        self.started = False
+
         self.loop = None

         self._server_setup()
@@ -228,6 +229,8 @@ class Server:
            if not listener.started:
                listener.start()

+        self.started = False
+
        if self.server:
            logger.info(f'Server{self.server_text} @ http://{self.host}:{self.port}')
@@ -292,7 +295,6 @@
         self.loop.call_soon_threadsafe(set_should_exit)

-
 class ListenerServer:
     '''
     Server abstraction to handle disparate listeners.
@@ -305,16 +307,16 @@ class ListenerServer:
            managed_listeners = []

        self.managed_listeners = managed_listeners
+        self.started = False

    def start(self):
-        signal.signal(signal.SIGINT, lambda s,f: self.shutdown())
-        signal.signal(signal.SIGTERM, lambda s,f: self.shutdown())
-
        for listener in self.managed_listeners:
            #loop.run_in_executor(None, partial(self.listener.start, loop=loop))
            if not listener.started:
                listener.start()

+        self.started = True
+
        for listener in self.managed_listeners:
            listener.join()
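
With the signal handlers removed from ListenerServer.start(), process-level signal wiring presumably moves to whoever owns the server. A hypothetical caller-side arrangement is sketched below; the import path and constructor keyword are assumptions, only shutdown() and the started flag appear in the diff.

# Hypothetical caller-side signal wiring after this change; import path and
# constructor signature are assumed, not taken from the diff.
import signal
from execlog.server import ListenerServer

server = ListenerServer(managed_listeners=[])   # add listeners here
signal.signal(signal.SIGINT, lambda s, f: server.shutdown())
signal.signal(signal.SIGTERM, lambda s, f: server.shutdown())

server.start()   # sets server.started = True once listeners run, then joins them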

View File

@@ -1,12 +1,21 @@
+import logging
+
 from pathlib import Path
+from concurrent.futures import as_completed
+
+from tqdm import tqdm
+from colorama import Fore, Back, Style
+from inotify_simple import flags as iflags

 from co3.resources import DiskResource
 from co3 import Differ, Syncer, Database

 from execlog.event import Event
 from execlog.routers import PathRouter
+from execlog.util.generic import color_text
+
+logger = logging.getLogger(__name__)

 class PathDiffer(Differ[Path]):
     def __init__(
         self,
@@ -101,16 +110,18 @@ class PathRouterSyncer(Syncer[Path]):
         '''
         return [
             self._construct_event(str(path), endpoint, iflags.MODIFY)
-            for endpoint, _ in path_tuples[1]
+            for endpoint, _ in path_tuples[0]
         ]

     def filter_diff_sets(self, l_excl, r_excl, lr_int):
         total_disk_files = len(l_excl) + len(lr_int)
+        total_joint_files = len(lr_int)

         def file_out_of_sync(p):
-            db_el, disk_el = lr_int[p]
+            _, db_el = lr_int[p]
             db_mtime = float(db_el[0].get('mtime','0'))
-            disk_mtime = File(p, disk_el[0]).mtime
+            disk_mtime = Path(p).stat().st_mtime
             return disk_mtime > db_mtime

         lr_int = {p:v for p,v in lr_int.items() if file_out_of_sync(p)}
@@ -119,10 +130,10 @@ class PathRouterSyncer(Syncer[Path]):
         oos_count = len(l_excl) + len(lr_int)
         oos_prcnt = oos_count / max(total_disk_files, 1) * 100

-        logger.info(color_text(Fore.GREEN, f'{len(l_excl)} new files to add'))
-        logger.info(color_text(Fore.YELLOW, f'{len(lr_int)} modified files'))
-        logger.info(color_text(Fore.RED, f'{len(r_excl)} files to remove'))
-        logger.info(color_text(Style.DIM, f'({oos_prcnt:.2f}%) of disk files out-of-sync'))
+        logger.info(color_text(f'{len(l_excl)} new files to add', Fore.GREEN))
+        logger.info(color_text(f'{len(lr_int)} modified files [{total_joint_files} up-to-date]', Fore.YELLOW))
+        logger.info(color_text(f'{len(r_excl)} files to remove', Fore.RED))
+        logger.info(color_text(f'({oos_prcnt:.2f}%) of disk files out-of-sync', Style.DIM))

         return l_excl, r_excl, lr_int
@@ -137,12 +148,18 @@ class PathRouterSyncer(Syncer[Path]):
         results = []
         for future in tqdm(
             as_completed(event_futures),
-            total=chunk_size,
-            desc=f'Awaiting chunk futures [submitted {len(event_futures)}/{chunk_size}]'
+            total=len(event_futures),
+            desc=f'Awaiting chunk futures [submitted {len(event_futures)}]'
         ):
             try:
-                results.append(future.result())
+                if not future.cancelled():
+                    results.append(future.result())
             except Exception as e:
                 logger.warning(f"Sync job failed with exception {e}")

         return results
+
+    def shutdown(self):
+        super().shutdown()
+        self.router.shutdown()
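
The new file_out_of_sync check above keeps only intersection entries whose on-disk mtime is newer than the mtime stored for the database row. A small self-contained illustration of that comparison follows; the db_row dict is fabricated for the example.

# Self-contained illustration of the mtime comparison in file_out_of_sync;
# the db_row dict below is fabricated for the example.
from pathlib import Path

def file_out_of_sync(path: str, db_row: dict) -> bool:
    db_mtime = float(db_row.get('mtime', '0'))   # mtime recorded at last sync
    disk_mtime = Path(path).stat().st_mtime      # current mtime on disk
    return disk_mtime > db_mtime                 # newer on disk => needs resync

print(file_out_of_sync(__file__, {'mtime': '0'}))   # True for any existing file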

View File

@@ -39,8 +39,8 @@ class ColorFormatter(logging.Formatter):
         formatter = self.FORMATS[submodule]
         name = record.name

-        if package == 'localsys':
-            name = f'localsys.{subsubmodule}'
+        if package == 'execlog':
+            name = f'execlog.{subsubmodule}'

         limit = 26
         name = name[:limit]