add router shutdown and task cancellation control

This commit is contained in:
Sam G. 2024-05-12 00:46:08 -07:00
parent 456fa7c84f
commit f9de5cfe4a
5 changed files with 104 additions and 35 deletions

View File

@ -401,6 +401,5 @@ class PathListener(Listener[FileEvent]):
self.write.write(b"\x00")
self.write.close()
if self.router.thread_pool is not None:
self.router.thread_pool.shutdown()
self.router.shutdown()

View File

@ -9,16 +9,16 @@ import traceback
import threading
from pathlib import Path
from typing import Any, Callable
from colorama import Fore, Style
from collections import defaultdict
from colorama import Fore, Back, Style
from functools import partial, update_wrapper
from concurrent.futures import ThreadPoolExecutor, wait, as_completed
from tqdm.auto import tqdm
from execlog.util.generic import color_text
from execlog.event import Event
from execlog.listener import Listener
from execlog.util.generic import color_text
logger = logging.getLogger(__name__)
@ -111,6 +111,10 @@ class Router[E: Event]:
# track event history
self.event_log = []
# shutdown flag, mostly for callbacks
self.should_exit = False
self._active_futures = set()
self._thread_pool = None
self._route_lock = threading.Lock()
@ -229,6 +233,11 @@ class Router[E: Event]:
Note: this method is expected to return a future. Perform any event-based
filtering before submitting a callback with this method.
'''
# exit immediately if exit flag is set
# if self.should_exit:
# return
callback = self.wrap_safe_callback(callback)
if inspect.iscoroutinefunction(callback):
if self.loop is None:
self.loop = asyncio.new_event_loop()
@ -245,7 +254,8 @@ class Router[E: Event]:
future = self.thread_pool.submit(
callback, *args, **kwargs
)
future.add_done_callback(handle_exception)
self._active_futures.add(future)
future.add_done_callback(self.general_task_done)
return future
@ -354,9 +364,10 @@ class Router[E: Event]:
future_results = []
for future in as_completed(futures):
try:
if not future.cancelled():
future_results.append(future.result())
except Exception as e:
logger.warning(f"Router callback job failed with exception {e}")
logger.warning(f"Router callback job failed with exception \"{e}\"")
return future_results
@ -375,6 +386,22 @@ class Router[E: Event]:
'''
self.running_events[event_idx].update(callbacks)
def wrap_safe_callback(self, callback: Callable):
'''
Check for shutdown flag and exit before running the callbacks.
Applies primarily to jobs enqueued by the ThreadPoolExecutor but not started when
an interrupt is received.
'''
def safe_callback(callback, *args, **kwargs):
if self.should_exit:
logger.debug('Exiting early from queued callback')
return
return callback(*args, **kwargs)
return partial(safe_callback, callback)
def filter(self, event: E, pattern, **listen_kwargs) -> bool:
'''
Determine if a given event matches the provided pattern
@ -435,8 +462,15 @@ class Router[E: Event]:
The check for results from the passed future allows us to know when in fact a
valid frame has finished, and a resubmission may be on the table.
'''
result = None
if not future.cancelled():
result = future.result()
if not result: return
else:
return None
# result should be *something* if work was scheduled
if not result:
return None
self.event_log.append((event, result))
queued_callbacks = self.stop_event(event)
@ -451,6 +485,29 @@ class Router[E: Event]:
def event_index(self, event):
return event[:2]
def shutdown(self):
logger.info(color_text('Router shutdown received', Fore.BLACK, Back.RED))
self.should_exit = True
for future in tqdm(
list(self._active_futures),
desc=color_text('Cancelling active futures...', Fore.BLACK, Back.RED),
colour='red',
):
future.cancel()
if self.thread_pool is not None:
self.thread_pool.shutdown(wait=False)
def general_task_done(self, future):
self._active_futures.remove(future)
try:
if not future.cancelled():
future.result()
except Exception as e:
logger.error(f"Exception occurred in threaded task: '{e}'")
#traceback.print_exc()
class ChainRouter[E: Event](Router[E]):
'''
@ -545,14 +602,6 @@ class ChainRouter[E: Event](Router[E]):
return listener
def handle_exception(future):
try:
future.result()
except Exception as e:
print(f"Exception occurred: {e}")
traceback.print_exc()
# RouterBuilder
def route(router, route_group, **route_kwargs):
def decorator(f):
@ -666,7 +715,9 @@ class RouterBuilder(ChainRouter, metaclass=RouteRegistryMeta):
# assumed no kwargs for passthrough
if route_group == 'post':
for method, _ in method_arg_list:
router.add_post_callback(method)
router.add_post_callback(
update_wrapper(partial(method, self), method),
)
continue
group_options = router_options.get(route_group)

View File

@ -15,7 +15,6 @@ listeners in one place.
make the Server definition more flexible.
'''
import re
import signal
import asyncio
import logging
import threading
@ -75,6 +74,8 @@ class Server:
self.server_text = ''
self.server_args = {}
self.started = False
self.loop = None
self._server_setup()
@ -228,6 +229,8 @@ class Server:
if not listener.started:
listener.start()
self.started = False
if self.server:
logger.info(f'Server{self.server_text} @ http://{self.host}:{self.port}')
@ -292,7 +295,6 @@ class Server:
self.loop.call_soon_threadsafe(set_should_exit)
class ListenerServer:
'''
Server abstraction to handle disparate listeners.
@ -305,16 +307,16 @@ class ListenerServer:
managed_listeners = []
self.managed_listeners = managed_listeners
self.started = False
def start(self):
signal.signal(signal.SIGINT, lambda s,f: self.shutdown())
signal.signal(signal.SIGTERM, lambda s,f: self.shutdown())
for listener in self.managed_listeners:
#loop.run_in_executor(None, partial(self.listener.start, loop=loop))
if not listener.started:
listener.start()
self.started = True
for listener in self.managed_listeners:
listener.join()

View File

@ -1,12 +1,21 @@
import logging
from pathlib import Path
from concurrent.futures import as_completed
from tqdm import tqdm
from colorama import Fore, Back, Style
from inotify_simple import flags as iflags
from co3.resources import DiskResource
from co3 import Differ, Syncer, Database
from execlog.event import Event
from execlog.routers import PathRouter
from execlog.util.generic import color_text
logger = logging.getLogger(__name__)
class PathDiffer(Differ[Path]):
def __init__(
self,
@ -101,16 +110,18 @@ class PathRouterSyncer(Syncer[Path]):
'''
return [
self._construct_event(str(path), endpoint, iflags.MODIFY)
for endpoint, _ in path_tuples[1]
for endpoint, _ in path_tuples[0]
]
def filter_diff_sets(self, l_excl, r_excl, lr_int):
total_disk_files = len(l_excl) + len(lr_int)
total_joint_files = len(lr_int)
def file_out_of_sync(p):
db_el, disk_el = lr_int[p]
_, db_el = lr_int[p]
db_mtime = float(db_el[0].get('mtime','0'))
disk_mtime = File(p, disk_el[0]).mtime
disk_mtime = Path(p).stat().st_mtime
return disk_mtime > db_mtime
lr_int = {p:v for p,v in lr_int.items() if file_out_of_sync(p)}
@ -119,10 +130,10 @@ class PathRouterSyncer(Syncer[Path]):
oos_count = len(l_excl) + len(lr_int)
oos_prcnt = oos_count / max(total_disk_files, 1) * 100
logger.info(color_text(Fore.GREEN, f'{len(l_excl)} new files to add'))
logger.info(color_text(Fore.YELLOW, f'{len(lr_int)} modified files'))
logger.info(color_text(Fore.RED, f'{len(r_excl)} files to remove'))
logger.info(color_text(Style.DIM, f'({oos_prcnt:.2f}%) of disk files out-of-sync'))
logger.info(color_text(f'{len(l_excl)} new files to add', Fore.GREEN)),
logger.info(color_text(f'{len(lr_int)} modified files [{total_joint_files} up-to-date]', Fore.YELLOW)),
logger.info(color_text(f'{len(r_excl)} files to remove', Fore.RED)),
logger.info(color_text(f'({oos_prcnt:.2f}%) of disk files out-of-sync', Style.DIM)),
return l_excl, r_excl, lr_int
@ -137,12 +148,18 @@ class PathRouterSyncer(Syncer[Path]):
results = []
for future in tqdm(
as_completed(event_futures),
total=chunk_size,
desc=f'Awaiting chunk futures [submitted {len(event_futures)}/{chunk_size}]'
total=len(event_futures),
desc=f'Awaiting chunk futures [submitted {len(event_futures)}]'
):
try:
if not future.cancelled():
results.append(future.result())
except Exception as e:
logger.warning(f"Sync job failed with exception {e}")
return results
def shutdown(self):
super().shutdown()
self.router.shutdown()

View File

@ -39,8 +39,8 @@ class ColorFormatter(logging.Formatter):
formatter = self.FORMATS[submodule]
name = record.name
if package == 'localsys':
name = f'localsys.{subsubmodule}'
if package == 'execlog':
name = f'execlog.{subsubmodule}'
limit = 26
name = name[:limit]