Files
trainlib/trainlib/utils/job.py
2026-03-03 18:11:37 -08:00

55 lines
1.4 KiB
Python

import logging
import concurrent
from concurrent.futures import Future, as_completed
from tqdm import tqdm
from colorama import Fore, Style
from mema.util.text import color_text
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)
def process_futures(
    futures: list[Future],
    desc: str | None = None,
    unit: str | None = None,
) -> None:
    """Block until every future has finished, showing a live progress bar.

    Futures are consumed in completion order. Each one is counted as
    succeeded, cancelled, or errored. The exception raised by a cancelled or
    failed future is logged and swallowed so the remaining futures are still
    awaited. Values returned by the futures are discarded.

    After each future completes, the bar's description is refreshed to
    ``"<desc> [<done> / <submitted> | <ok> <cancelled> <errored>]"``. The
    counts are passed through ``color_text`` with green, yellow, and red
    respectively, and the running total with ``Style.BRIGHT``.

    Args:
        futures: Futures to wait on.
        desc: Label shown on the progress bar. Defaults to
            ``"Awaiting futures"``.
        unit: Unit name shown on the progress bar. Defaults to ``"it"``.
    """
    if desc is None:
        desc = "Awaiting futures"
    if unit is None:
        unit = "it"

    success = 0
    cancelled = 0
    errored = 0
    submitted = len(futures)

    # Use the bar as a context manager so it is closed even when something
    # other than a future's own failure escapes the loop (for example a
    # KeyboardInterrupt while waiting).
    with tqdm(
        total=submitted,
        desc=f"{desc} [submitted {submitted}]",
        unit=unit,
    ) as progress_bar:
        for future in as_completed(futures):
            # Only ``result()`` is guarded; the success counter lives in the
            # ``else`` branch so the ``try`` body stays minimal.
            try:
                future.result()
            except concurrent.futures.CancelledError as e:
                cancelled += 1
                logger.error('Future cancelled; "%s"', e)
            except Exception as e:
                # Deliberately broad: one failing job must not stop us from
                # awaiting the rest.
                errored += 1
                logger.warning('Future failed with unknown exception "%s"', e)
            else:
                success += 1

            suc_txt = color_text(f"{success}", Fore.GREEN)
            can_txt = color_text(f"{cancelled}", Fore.YELLOW)
            err_txt = color_text(f"{errored}", Fore.RED)
            tot_txt = color_text(f"{success+cancelled+errored}", Style.BRIGHT)

            progress_bar.set_description(
                f"{desc} [{tot_txt} / {submitted} | {suc_txt} {can_txt} {err_txt}]"
            )
            progress_bar.update(n=1)