From 4a18b22521ed370a24bccaed38d12d14b4e2dd0f Mon Sep 17 00:00:00 2001
From: Elad Venezian
Date: Tue, 4 Mar 2025 09:33:18 +0200
Subject: [PATCH 01/28] Succeed in creating a simple server and interacting
 with it

Add auth endpoint
Working with a single server
Succeed at a simple multi-server scenario
Works for the case where there are more batches than workers
Mimic server error
Auto test, and store in cache.
Add shutdown mechanism
Start working on CCC

Signed-off-by: Elad Venezian
---
 ccc_worker_server.py    | 123 ++++++++++++++++++++++++
 src/unitxt/inference.py | 208 +++++++++++++++++++++++++++++++++++++++-
 test_caching_ccc.py     |  64 +++++++++++++
 test_caching_try.py     | 112 ++++++++++++++++++++++
 utils/.secrets.baseline |  54 ++++-------
 5 files changed, 520 insertions(+), 41 deletions(-)
 create mode 100644 ccc_worker_server.py
 create mode 100644 test_caching_ccc.py
 create mode 100644 test_caching_try.py

diff --git a/ccc_worker_server.py b/ccc_worker_server.py
new file mode 100644
index 0000000000..1ef71d0f1b
--- /dev/null
+++ b/ccc_worker_server.py
@@ -0,0 +1,123 @@
+
+import logging
+import os
+import random
+import sys
+import threading
+import time
+
+import requests
+from flask import Flask, jsonify, request
+from unitxt.inference import HFPipelineBasedInferenceEngine
+
+logging.basicConfig(level=logging.INFO)
+
+app = Flask(__name__)
+PORT = None
+
+class Server:
+    def __init__(self):
+        self.inference_engine = None
+        self.inactivity_timeout = 600
+        self.monitor_thread = threading.Thread(target=self.monitor_activity, daemon=True)
+        self.last_request_time = time.time()
+        self.shutdown_flag = False
+        self.monitor_thread.start()
+
+    def update_last_request_time(self):
+        self.last_request_time = time.time()
+
+    def monitor_activity(self):
+        while not self.shutdown_flag:
+            time.sleep(5)
+            if time.time() - self.last_request_time > self.inactivity_timeout:
+                app.logger.info(f"No requests for {self.inactivity_timeout} seconds. Shutting down server...")
+                try:
+                    requests.post(f"http://localhost:{PORT}/shutdown", timeout=5)
+                except Exception:
+                    pass
+            else:
+                app.logger.info(
+                    f"{int(self.inactivity_timeout - (time.time() - self.last_request_time))} seconds till shutdown...")
+
+    def shutdown_server(self):
+        self.shutdown_flag = True
+        app.logger.info("Server shutting down...")
+        shutdown_func = request.environ.get("werkzeug.server.shutdown")
+        if shutdown_func:
+            shutdown_func()
+        # Allow the shutdown process to complete, then force exit the program
+        time.sleep(1)
+        os._exit(0)  # This immediately stops the program
+
+    def init_server(self, **kwargs):
+        kwargs["use_cache"] = True
+        self.inference_engine = HFPipelineBasedInferenceEngine(**kwargs)
+
+    def infer(self, **kwargs):
+        inputs = []
+        return self.inference_engine(inputs)
+
+
+server = Server()
+
+@app.before_request
+def update_activity():
+    server.update_last_request_time()
+
+
+@app.route("/shutdown", methods=["POST"])
+def shutdown():
+    app.logger.info("Received shutdown request")
+    server.shutdown_server()
+    return jsonify({"message": "Shutting down server..."}), 200
+
+
+@app.route("/init_server", methods=["POST"])
+def init_server():
+    kwargs = request.get_json()
+    server.init_server(**kwargs)
+    return jsonify("Accepted")
+
+
+@app.route("/<model>/v1/chat/completions", methods=["POST"])
+@app.route("/<model_prefix>/<model>/v1/chat/completions", methods=["POST"])
+def completions(model: str, model_prefix: str = "None"):
+    if random.random() < 0:
+        logging.error("Bad luck! Returning 500 with an error message.")
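+        # NOTE: the threshold is 0, so this simulated-failure branch never fires;
+        # a positive value (e.g., random.random() < 0.05) would mimic a flaky worker.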
Returning 500 with an error message.") + app.logger.info("Server shutting down...") + shutdown_func = request.environ.get("werkzeug.server.shutdown") + if shutdown_func: + shutdown_func() + # Allow the shutdown process to complete, then force exit the program + time.sleep(1) + os._exit(0) # This immediately stops the program + return jsonify({"error": "Bad luck, something went wrong!"}), 500 + + body = request.get_json() + # validate that request parameters are equal to the model config. Print warnings if not. + for k, v in body.items(): + if k == "messages": + continue + k = "model_name" if k == "model" else k + attr = getattr(server.inference_engine, k, None) + if attr is None: + logging.warning(f"Warning: {k} is not an attribute in inference_engine") + else: + if attr != v: + logging.warning(f"Warning: {k} value in boody({v}) is different from value in inference engine ({attr})") + texts = [{"source": m[0]["content"]} for m in body["messages"]] + predictions = server.inference_engine(texts) + return jsonify({ + "choices": [{"message": {"role": "assistant","content": p}} for p in predictions], + }) + + +@app.route("/status", methods=["GET"]) +def status(): + return "up", 200 + + +if __name__ == "__main__": + PORT = sys.argv[1] + app.run(host="0.0.0.0", port=PORT, debug=True) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index fae694f708..5a8a850815 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -7,11 +7,14 @@ import json import logging import os +import random import re import sys +import threading import time import uuid from collections import Counter +from concurrent.futures import Future, ThreadPoolExecutor, wait from datetime import datetime from itertools import islice from multiprocessing.pool import ThreadPool @@ -30,6 +33,7 @@ Union, ) +import requests from datasets import Dataset, DatasetDict, Image from tqdm import tqdm, trange from tqdm.asyncio import tqdm_asyncio @@ -276,7 +280,7 @@ def infer( if prediction is None: continue cache_key = self._get_cache_key(item) - self._cache[cache_key] = prediction + self.store_in_cache(cache_key, prediction) else: inferred_results = [] # Combine cached and inferred results in original order @@ -286,6 +290,9 @@ def infer( result.extend(batch_predictions) else: result = self._infer(dataset, return_meta_data) + + result = self.post_process_results(result) + return ListWithMetadata( result, metadata={ @@ -295,6 +302,12 @@ def infer( }, ) + def store_in_cache(self, cache_key, prediction): + self._cache[cache_key] = prediction + + def post_process_results(self, result): + return result + def _mock_infer( self, dataset: Union[List[Dict[str, Any]], Dataset], @@ -1957,7 +1970,7 @@ def prepare_engine(self): @staticmethod def get_base_url_from_model_name(model_name: str): base_url_template = ( - "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/{}" + "http://localhost:5000/{}" ) return base_url_template.format( RITSInferenceEngine._get_model_name_for_endpoint(model_name) @@ -3546,10 +3559,9 @@ def _infer( dataset: Union[List[Dict[str, Any]], Dataset], return_meta_data: bool = False, ) -> Union[List[str], List[TextGenerationInferenceOutput]]: - if return_meta_data and not hasattr(self.engine, "get_return_object"): + if return_meta_data: raise NotImplementedError( - f"Inference engine {self.engine.__class__.__name__} does not support return_meta_data as it " - f"does not contain a 'get_return_object' method. Please set return_meta_data=False." 
+ f"Inference engine {self.engine.__class__.__name__} does not support return_meta_data." ) inputs = [] @@ -3576,3 +3588,189 @@ def _infer( predictions.append(options_scores.most_common(1)[0][0]) return predictions + + +class MultiServersInferenceEngine(OpenAiInferenceEngine, + HFGenerationParamsMixin): + + workers_url: List[str] + + def post_server(self, server_url, endpoint, data): + headers = {"Content-Type": "application/json"} + response = requests.post(url=f"{server_url}/{endpoint}", json=data, headers=headers) + response.raise_for_status() + return response.json() + + def prepare_engine(self): + from openai import OpenAI + self.lock = threading.Lock() + self.workers_state = {} + credentials = self._prepare_credentials() + for url in self.workers_url: + init_result = self.post_server(endpoint="init_server",server_url=url, + data={**self.to_dict([HFGenerationParamsMixin]), **{"model_name": self.model_name}}) + if init_result == "Accepted": + self.add_worker(url, client=OpenAI( + api_key=credentials["api_key"], + base_url= f"{url}/{self.model_name}" + "/v1", + default_headers=self.get_default_headers(), + )) + + #def init_server_and_add_to_workers_list + + + def add_worker(self, url, client): + with self.lock: + self.workers_state[url] = {"status": "ready", "client": client} + + def release_worker(self, url): + with self.lock: + self.workers_state[url]["status"] = "ready" + + def assign_worker(self): + with self.lock: + while True: + # print("trying to assign worker...") + for url, rec in self.workers_state.items(): + if rec["status"] == "ready": + rec["status"] ="assigned" + return url, rec["client"] + time.sleep(random.uniform(0, 1)) + + def _prepare_credentials(self) -> CredentialsOpenAi: + return {"api" + "_" + "key": "no-api-key",} + + def _infer( + self, + dataset: Union[List[Dict[str, Any]], Dataset], + return_meta_data: bool = False, + ) -> List[Any]: # Now returns a Future object + """Runs inference in parallel, returning futures for each batch.""" + # Lazy-initialize executor if not already created + if not hasattr(self, "_executor"): + self._executor = ThreadPoolExecutor(max_workers=len(self.workers_state)) + + # Submit the batch job + batch_future = self._executor.submit(self._run_batch, dataset, return_meta_data) + + # Create individual futures that resolve when batch_future is done + element_futures = [Future() for _ in dataset] + + def set_results(batch_fut: Future): + """Callback to set individual results once batch computation is done.""" + try: + results = batch_fut.result() # Get the batch results + for i, res in enumerate(results): + element_futures[i].set_result(res) # Set each individual future + except Exception as e: + for f in element_futures: + f.set_exception(e) # Propagate any exception + + # Attach the callback to the batch future + batch_future.add_done_callback(set_results) + + return element_futures # Return a list of futures + + def _run_batch(self, batch, return_meta_data): + """Helper function to process a batch inside a thread.""" + logger.info(f"Trying to get assigned: {self.workers_state}") + url, client = self.assign_worker() + logger.info(f"Thread {url} processing batch: {self.workers_state}") + messages = [self.to_messages(instance) for instance in batch] + logger.info(f"a {url}") + try: + response = client.chat.completions.create( + messages=messages, + model=self.model_name, + **self._get_completion_kwargs(), + ) + logger.info(f"response: {response}") + predictions = [r.message.content for r in response.choices] + result = 
[self.get_return_object(p, response, return_meta_data) for p in predictions] + finally: + logger.info(f"Thread {url} release state:") + self.release_worker(url) + logger.info(f"Thread {url} release state done: {self.workers_state}") + return result + + def post_process_results(self, result): + futures = [r for r in result if isinstance(r, Future)] + if futures: + wait(futures) + + return [r.result() if isinstance(r, Future) else r for r in result] + + def store_in_cache(self, cache_key, prediction): + if isinstance(prediction, Future): + def store_after_pack_in_cache(future, cache_key): + prediction = future.result() + if prediction is not None: + self._cache[cache_key] = prediction + + prediction.add_done_callback(lambda f, key=cache_key: store_after_pack_in_cache(f, key)) + else: + self._cache[cache_key] = prediction + + +class CCCInferenceEngine(MultiServersInferenceEngine): + ccc_host: str + ccc_user: str + ccc_path: str + ccc_python: str + server_port: str = "5000" + num_of_workers: int = 5 + workers_url: List[str] = [] + + def prepare_engine(self): + assert not self.workers_url, "CCCInferenceEngine doesn't support explicit setting of workers_url" + self.start_ccc_servers() + self.prepare_engine() + + def start_ccc_servers(self): + import paramiko + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.ccc_host, username=self.ccc_user) + ssh.exec_command(f"mkdir -p {self.ccc_path}") + self.ccc_jobs = {} + for i in range(self.num_of_workers): + command = f"bash -l -c 'jbsub -queue x86_6h -cores 4+1 -require v100 -mem 24G -out ~/server{i}.log {self.ccc_python} /dccstor/fuse/unitxt/ccc_worker_server.py {self.server_port}'" + stdin, stdout, stderr = ssh.exec_command(command) + job_output = stdout.read().decode().strip() + job_error = stderr.read().decode().strip() + match = re.search(r"Job <(\d+)> is submitted", job_output) + if match: + job_id = match.group(1) + logger.info(f"Start job ID: {job_id}") + self.ccc_jobs[job_id] ={"status": "AVAIL", "log_id": i} + else: + raise RuntimeError(f"Failed to run jbsub on host {self.ccc_host}.\nstdout: {job_output}.\nstderr: {job_error}") + + def run_monitor_ccc_jobs(ssh, sample_every): + while True: + command = "bash -l -c 'jbinfo'" + stdin, stdout, stderr = ssh.exec_command(command) + output = stdout.read().decode().strip() + #error = stderr.read().decode().strip() + for job_id in self.ccc_jobs.keys(): + match = re.search(rf"^{job_id}\s+\S+\s+(\w+)", output, re.MULTILINE) + if match: + status = match.group(1) + if status != self.ccc_jobs[job_id]["status"]: + if self.ccc_jobs[job_id]["status"] == "RUN": + pass # add server to server list + elif status == "RUN": + pass # remove server from server list. Consider fetching the server log. 
+ self.ccc_jobs[job_id]["status"] = status + logger.info(f"status has been changed: {job_id} - {status}") + + + time.sleep(sample_every) + + thread = threading.Thread(target=run_monitor_ccc_jobs, args=(ssh, 10)) + thread.daemon = True # + thread.start() + + + time.sleep(200) # This keeps the main thread alive so the background thread can continue + diff --git a/test_caching_ccc.py b/test_caching_ccc.py new file mode 100644 index 0000000000..c77710b676 --- /dev/null +++ b/test_caching_ccc.py @@ -0,0 +1,64 @@ +import hashlib +import json +import logging +import os +import time + +import joblib +import unitxt +from unitxt import load_dataset +from unitxt.inference import CCCInferenceEngine +from unitxt.logging_utils import set_verbosity + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + +logger = logging.getLogger(__name__) + + +def get_cache_filename(cache_dir="cache", **kwargs): + """Generate a unique filename for caching based on function arguments.""" + os.makedirs(cache_dir, exist_ok=True) + hash_key = hashlib.md5(json.dumps(kwargs, sort_keys=True).encode()).hexdigest() + return os.path.join(cache_dir, f"dataset_{hash_key}.pkl") + + +def load_dataset_cached(**kwargs): + """Load dataset with disk caching.""" + cache_file = get_cache_filename(**kwargs) + + if os.path.exists(cache_file): + return joblib.load(cache_file) + + data = load_dataset(**kwargs) + joblib.dump(data, cache_file) + return data + + +if __name__ == "__main__": + set_verbosity("debug") + unitxt.settings.allow_unverified_code = True + dataset = load_dataset_cached(card="cards.openbook_qa", + split="test") + + dataset = dataset.select(range(100)) + + inference_model = CCCInferenceEngine( + model_name="google/flan-t5-small", + temperature=0.202, + max_new_tokens=256, + use_cache=True, + cache_batch_size=5, + ccc_host="cccxl013.pok.ibm.com", + ccc_user="eladv", + ccc_path="", + ccc_python="/dccstor/fuse/eladv_envs/unitxt/bin/python", + num_of_workers=3, + ) + + start_time = time.time() + predictions = inference_model.infer(dataset) + end_time = time.time() + + logger.info(f"predictions contains {predictions.count(None)} Nones") + for p in predictions: + logger.info(f"prediction: {p}") diff --git a/test_caching_try.py b/test_caching_try.py new file mode 100644 index 0000000000..a851ad46ce --- /dev/null +++ b/test_caching_try.py @@ -0,0 +1,112 @@ +import hashlib +import json +import logging +import os +import subprocess +import time + +import joblib +import requests +import unitxt +from unitxt import load_dataset +from unitxt.inference import MultiServersInferenceEngine +from unitxt.logging_utils import set_verbosity + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + +logger = logging.getLogger(__name__) + + +def get_cache_filename(cache_dir="cache", **kwargs): + """Generate a unique filename for caching based on function arguments.""" + os.makedirs(cache_dir, exist_ok=True) + hash_key = hashlib.md5(json.dumps(kwargs, sort_keys=True).encode()).hexdigest() + return os.path.join(cache_dir, f"dataset_{hash_key}.pkl") + + +def load_dataset_cached(**kwargs): + """Load dataset with disk caching.""" + cache_file = get_cache_filename(**kwargs) + + if os.path.exists(cache_file): + return joblib.load(cache_file) + + data = load_dataset(**kwargs) + joblib.dump(data, cache_file) + return data + + +def run_worker_in_a_port(port): + kill_process_on_port(port) + process = subprocess.Popen( + ["/Users/eladv/miniforge3/envs/unitxt/bin/python", 
"/Users/eladv/unitxt/ccc_worker_server.py", f"{port}"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True) + logger.info(f"Started worker on port {port} with PID {process.pid}") + return process + +def kill_process_on_port(port): + try: + output = subprocess.check_output(f"lsof -ti:{port}", shell=True).decode().strip() + if output: + logger.info(f"Killing process {output} on port {port}...") + subprocess.run(f"kill -9 {output}", shell=True) + except subprocess.CalledProcessError: + pass + + +def is_up(server_url): + try: + response = requests.get(f"{server_url}/status", timeout=5) + return response.text.strip().lower() == "up" + except requests.RequestException: + return False + +def set_up_worker_servers(servers): + ports = [5000, 5001, 5002, 5003, 5004] + processes = [(port, run_worker_in_a_port(port)) for port in ports] + while len(processes) > 0: + for port, process in processes: + # Check if the process is still running + if process.poll() is not None: # If poll() returns None, the process is still running + stdout, stderr = process.communicate() # Get the output and errors + logger.error(f"Process on port {port} has stopped!") + logger.error(f"STDOUT:\n{stdout}") + logger.error(f"STDERR:\n{stderr}") + raise RuntimeError(f"Failed to Start server on port: {port}") + if is_up(f"http://localhost:{port}"): + processes.remove((port, process)) + time.sleep(0.3) + logger.info(f"The following servers still need to start: {[p[0] for p in processes]}") + + +if __name__ == "__main__": + #ports = [5000,5001,5002,5003,5004] + #servers = [f"http://localhost:{port}" for port in ports] + hosts = ["cccxc425","cccxc417"]# ,'cccxc436'] + servers = [f"http://{server}.pok.ibm.com:5000" for server in hosts] + #set_up_worker_servers(servers) + set_verbosity("debug") + unitxt.settings.allow_unverified_code = True + dataset = load_dataset_cached(card="cards.openbook_qa", + split="test") + + dataset = dataset.select(range(100)) + + inference_model = MultiServersInferenceEngine( + model_name="google/flan-t5-small", + temperature=0.202, + max_new_tokens=256, + use_cache=True, + cache_batch_size=5, + workers_url=servers + ) + + start_time = time.time() + predictions = inference_model.infer(dataset) + end_time = time.time() + + logger.info(f"predictions contains {predictions.count(None)} Nones") + for p in predictions: + logger.info(f"prediction: {p}") diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline index db6e6ed1e5..7fd3857fcf 100644 --- a/utils/.secrets.baseline +++ b/utils/.secrets.baseline @@ -127,58 +127,40 @@ } ], "results": { - "src/unitxt/catalog/cards/tablebench.json": [ + "src/unitxt/loaders.py": [ { - "type": "Hex High Entropy String", - "filename": "src/unitxt/catalog/cards/tablebench.json", - "hashed_secret": "fab1cac10b07d605b4ab506d750364b1327c8367", - "is_verified": false, - "line_number": 6 - } - ], - "src/unitxt/catalog/cards/tablebench_data_analysis.json": [ - { - "type": "Hex High Entropy String", - "filename": "src/unitxt/catalog/cards/tablebench_data_analysis.json", - "hashed_secret": "fab1cac10b07d605b4ab506d750364b1327c8367", - "is_verified": false, - "line_number": 6 - } - ], - "src/unitxt/catalog/cards/tablebench_fact_checking.json": [ - { - "type": "Hex High Entropy String", - "filename": "src/unitxt/catalog/cards/tablebench_fact_checking.json", - "hashed_secret": "fab1cac10b07d605b4ab506d750364b1327c8367", + "type": "Secret Keyword", + "filename": "src/unitxt/loaders.py", + "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742", "is_verified": false, - 
"line_number": 6 + "line_number": 602 } ], - "src/unitxt/catalog/cards/tablebench_numerical_reasoning.json": [ + "src/unitxt/metrics.py": [ { "type": "Hex High Entropy String", - "filename": "src/unitxt/catalog/cards/tablebench_numerical_reasoning.json", - "hashed_secret": "fab1cac10b07d605b4ab506d750364b1327c8367", + "filename": "src/unitxt/metrics.py", + "hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889", "is_verified": false, - "line_number": 6 + "line_number": 71 } ], - "utils/.secrets.baseline": [ + "tests/library/test_loaders.py": [ { - "type": "Hex High Entropy String", - "filename": "utils/.secrets.baseline", - "hashed_secret": "ec3d2a39003a5b76da13be9b51af01cf6d616efb", + "type": "Secret Keyword", + "filename": "tests/library/test_loaders.py", + "hashed_secret": "8d814baafe5d8412572dc520dcab83f60ce1375c", "is_verified": false, - "line_number": 134 + "line_number": 125 }, { "type": "Secret Keyword", - "filename": "utils/.secrets.baseline", - "hashed_secret": "ec3d2a39003a5b76da13be9b51af01cf6d616efb", + "filename": "tests/library/test_loaders.py", + "hashed_secret": "42a472ac88cd8d43a2c5ae0bd0bdf4626cdaba31", "is_verified": false, - "line_number": 134 + "line_number": 135 } ] }, - "generated_at": "2025-04-02T14:52:32Z" + "generated_at": "2025-03-11T11:51:35Z" } From 9504d0c6cf19e960ebd66a5a29118a8e8f01e04d Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Mon, 28 Apr 2025 12:03:04 +0300 Subject: [PATCH 02/28] minor changes to MultiServersInferenceEngine --- src/unitxt/inference.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 5a8a850815..03651dd609 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -3590,12 +3590,13 @@ def _infer( return predictions -class MultiServersInferenceEngine(OpenAiInferenceEngine, - HFGenerationParamsMixin): +ParamDataClass = TypeVar("ParamDataClass") +class MultiServersInferenceEngine(OpenAiInferenceEngine, ParamDataClass): workers_url: List[str] - def post_server(self, server_url, endpoint, data): + @staticmethod + def post_server(server_url: str, endpoint:str , data: Dict) -> str: headers = {"Content-Type": "application/json"} response = requests.post(url=f"{server_url}/{endpoint}", json=data, headers=headers) response.raise_for_status() @@ -3606,24 +3607,29 @@ def prepare_engine(self): self.lock = threading.Lock() self.workers_state = {} credentials = self._prepare_credentials() + assert len(self.workers_url) > 0, "No workers_url are set." 
for url in self.workers_url: - init_result = self.post_server(endpoint="init_server",server_url=url, - data={**self.to_dict([HFGenerationParamsMixin]), **{"model_name": self.model_name}}) + init_result = self.post_server(endpoint="init_server", + server_url=url, + data={**self.to_dict([ParamDataClass]), + **{"model_name": self.model_name}}) if init_result == "Accepted": self.add_worker(url, client=OpenAI( api_key=credentials["api_key"], base_url= f"{url}/{self.model_name}" + "/v1", default_headers=self.get_default_headers(), )) + else: + raise RuntimeError(f"worker_url ({url}/{self.model_name}) initialization failed: {init_result}") #def init_server_and_add_to_workers_list - def add_worker(self, url, client): + def add_worker(self, url: str, client) -> None: with self.lock: self.workers_state[url] = {"status": "ready", "client": client} - def release_worker(self, url): + def release_worker(self, url: str): with self.lock: self.workers_state[url]["status"] = "ready" From 01e3603beae2cdda1811bb21310ecef20215512e Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Mon, 28 Apr 2025 16:25:21 +0300 Subject: [PATCH 03/28] minor changes to MultiServersInferenceEngine --- src/unitxt/inference.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 03651dd609..165be762c8 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -3591,7 +3591,7 @@ def _infer( ParamDataClass = TypeVar("ParamDataClass") -class MultiServersInferenceEngine(OpenAiInferenceEngine, ParamDataClass): +class MultiServersInferenceEngine(OpenAiInferenceEngine): workers_url: List[str] @@ -3732,12 +3732,14 @@ def prepare_engine(self): self.start_ccc_servers() self.prepare_engine() + def start_ccc_servers(self): import paramiko ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.connect(self.ccc_host, username=self.ccc_user) ssh.exec_command(f"mkdir -p {self.ccc_path}") + self.ccc_jobs = {} for i in range(self.num_of_workers): command = f"bash -l -c 'jbsub -queue x86_6h -cores 4+1 -require v100 -mem 24G -out ~/server{i}.log {self.ccc_python} /dccstor/fuse/unitxt/ccc_worker_server.py {self.server_port}'" From d028e42b90f4d806b6a78803ed8b432a203c2187 Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Mon, 28 Apr 2025 21:01:35 +0300 Subject: [PATCH 04/28] updates --- ccc_worker_server.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ccc_worker_server.py b/ccc_worker_server.py index 1ef71d0f1b..7448f0e1c8 100644 --- a/ccc_worker_server.py +++ b/ccc_worker_server.py @@ -2,6 +2,7 @@ import logging import os import random +import socket import sys import threading import time @@ -121,3 +122,9 @@ def status(): if __name__ == "__main__": PORT = sys.argv[1] app.run(host="0.0.0.0", port=PORT, debug=True) + + hostname = socket.gethostname() + ip_address = socket.gethostbyname(hostname) + logging.INFO(f"Server hostname: {hostname}") + logging.INFO(f"Server IP address (may be 127.0.0.1 if accessed locally): {ip_address}") + From e4eecf3737f66d1eb19e21e706e953471966f31a Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Mon, 28 Apr 2025 21:08:10 +0300 Subject: [PATCH 05/28] updates --- ccc_worker_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ccc_worker_server.py b/ccc_worker_server.py index 7448f0e1c8..042e62e83f 100644 --- a/ccc_worker_server.py +++ b/ccc_worker_server.py @@ -125,6 +125,6 @@ def status(): hostname = socket.gethostname() ip_address = socket.gethostbyname(hostname) - 
logging.INFO(f"Server hostname: {hostname}") - logging.INFO(f"Server IP address (may be 127.0.0.1 if accessed locally): {ip_address}") + logging.CRITICAL(f"Server hostname: {hostname}") + logging.CRITICAL(f"Server IP address (may be 127.0.0.1 if accessed locally): {ip_address}") From a0176ed9c4204dcaab143e91f76d9ac9e90e8477 Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Mon, 28 Apr 2025 21:09:19 +0300 Subject: [PATCH 06/28] updates --- ccc_worker_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ccc_worker_server.py b/ccc_worker_server.py index 042e62e83f..86b456879e 100644 --- a/ccc_worker_server.py +++ b/ccc_worker_server.py @@ -125,6 +125,6 @@ def status(): hostname = socket.gethostname() ip_address = socket.gethostbyname(hostname) - logging.CRITICAL(f"Server hostname: {hostname}") - logging.CRITICAL(f"Server IP address (may be 127.0.0.1 if accessed locally): {ip_address}") + app.logger(f"Server hostname: {hostname}") + app.logger(f"Server IP address (may be 127.0.0.1 if accessed locally): {ip_address}") From b574f5720842d5ecb07f9518cb6f70e74927d5bd Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Mon, 28 Apr 2025 21:11:43 +0300 Subject: [PATCH 07/28] updates --- ccc_worker_server.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ccc_worker_server.py b/ccc_worker_server.py index 86b456879e..49b1e2d153 100644 --- a/ccc_worker_server.py +++ b/ccc_worker_server.py @@ -40,6 +40,10 @@ def monitor_activity(self): else: app.logger.info( f"{int(self.inactivity_timeout - (time.time() - self.last_request_time))} till shutdown...") + hostname = socket.gethostname() + ip_address = socket.gethostbyname(hostname) + app.logger.info(f"Server hostname: {hostname}") + app.logger.info(f"Server IP address (may be 127.0.0.1 if accessed locally): {ip_address}") def shutdown_server(self): self.shutdown_flag = True @@ -123,8 +127,5 @@ def status(): PORT = sys.argv[1] app.run(host="0.0.0.0", port=PORT, debug=True) - hostname = socket.gethostname() - ip_address = socket.gethostbyname(hostname) - app.logger(f"Server hostname: {hostname}") - app.logger(f"Server IP address (may be 127.0.0.1 if accessed locally): {ip_address}") + From 004d8518c3ac4da53ed36dba113343bc24b33f0d Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Mon, 28 Apr 2025 21:19:24 +0300 Subject: [PATCH 08/28] updates --- ccc_worker_server.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ccc_worker_server.py b/ccc_worker_server.py index 49b1e2d153..1fb1e6f945 100644 --- a/ccc_worker_server.py +++ b/ccc_worker_server.py @@ -25,6 +25,10 @@ def __init__(self): self.shutdown_flag = False self.monitor_thread.start() + hostname = socket.gethostname() + ip_address = socket.gethostbyname(hostname) + app.logger.info(f"*** Server IP address: '{ip_address}/{PORT}' ****") + def update_last_request_time(self): self.last_request_time = time.time() @@ -40,10 +44,7 @@ def monitor_activity(self): else: app.logger.info( f"{int(self.inactivity_timeout - (time.time() - self.last_request_time))} till shutdown...") - hostname = socket.gethostname() - ip_address = socket.gethostbyname(hostname) - app.logger.info(f"Server hostname: {hostname}") - app.logger.info(f"Server IP address (may be 127.0.0.1 if accessed locally): {ip_address}") + def shutdown_server(self): self.shutdown_flag = True From 7467c7afea5adea65d7af76562e652bd0de8dc99 Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Tue, 29 Apr 2025 11:54:30 +0300 Subject: [PATCH 09/28] add the inference server as a cmd line --- pyproject.toml 
| 1 + .../unitxt/service/inference_server.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) rename ccc_worker_server.py => src/unitxt/service/inference_server.py (91%) diff --git a/pyproject.toml b/pyproject.toml index 06dc369582..5ed8201a2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,7 @@ unitxt-explore = "unitxt.ui:launch" unitxt-assistant = "unitxt.assistant:launch" unitxt-metrics-service = "unitxt.service.metrics.main:start_metrics_http_service" unitxt-evaluate = "unitxt.evaluate_cli:main" +unitxt-inference-server = "unitxt.service.inference_server:main" [tool.ruff] exclude = [ diff --git a/ccc_worker_server.py b/src/unitxt/service/inference_server.py similarity index 91% rename from ccc_worker_server.py rename to src/unitxt/service/inference_server.py index 1fb1e6f945..62734248ad 100644 --- a/ccc_worker_server.py +++ b/src/unitxt/service/inference_server.py @@ -1,15 +1,15 @@ - +import argparse import logging import os import random import socket -import sys import threading import time import requests from flask import Flask, jsonify, request -from unitxt.inference import HFPipelineBasedInferenceEngine + +from ..inference import HFPipelineBasedInferenceEngine logging.basicConfig(level=logging.INFO) @@ -27,7 +27,7 @@ def __init__(self): hostname = socket.gethostname() ip_address = socket.gethostbyname(hostname) - app.logger.info(f"*** Server IP address: '{ip_address}/{PORT}' ****") + app.logger.info(f"server_ip={ip_address} server_port={PORT}") def update_last_request_time(self): self.last_request_time = time.time() @@ -125,8 +125,10 @@ def status(): if __name__ == "__main__": - PORT = sys.argv[1] - app.run(host="0.0.0.0", port=PORT, debug=True) + parser = argparse.ArgumentParser(prog="unitxt inference worker server") + parser.add_argument("port", type=int, help="Port to run the server on", default=8080) + args = parser.parse_args() + app.run(host="0.0.0.0", port=args.port, debug=True) From 3b2adae459165e62bc2d0baaa57ccccf9915cd39 Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Tue, 29 Apr 2025 13:08:27 +0300 Subject: [PATCH 10/28] update --- src/unitxt/inference.py | 233 ++++++++++++++++++++++++++++------------ 1 file changed, 167 insertions(+), 66 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 165be762c8..f12dfad665 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -1,5 +1,6 @@ import abc import asyncio +import atexit import base64 import dataclasses import hashlib @@ -7,8 +8,8 @@ import json import logging import os -import random import re +import signal import sys import threading import time @@ -3593,7 +3594,7 @@ def _infer( ParamDataClass = TypeVar("ParamDataClass") class MultiServersInferenceEngine(OpenAiInferenceEngine): - workers_url: List[str] + workers_url: List[str] = [] @staticmethod def post_server(server_url: str, endpoint:str , data: Dict) -> str: @@ -3603,29 +3604,31 @@ def post_server(server_url: str, endpoint:str , data: Dict) -> str: return response.json() def prepare_engine(self): - from openai import OpenAI + self._register_cleanup_handlers() self.lock = threading.Lock() self.workers_state = {} - credentials = self._prepare_credentials() - assert len(self.workers_url) > 0, "No workers_url are set." 
+ for url in self.workers_url: - init_result = self.post_server(endpoint="init_server", - server_url=url, - data={**self.to_dict([ParamDataClass]), - **{"model_name": self.model_name}}) - if init_result == "Accepted": - self.add_worker(url, client=OpenAI( - api_key=credentials["api_key"], - base_url= f"{url}/{self.model_name}" + "/v1", - default_headers=self.get_default_headers(), - )) - else: - raise RuntimeError(f"worker_url ({url}/{self.model_name}) initialization failed: {init_result}") + self.add_worker(url) #def init_server_and_add_to_workers_list + def add_worker(self, url: str) -> None: + from openai import OpenAI + credentials = self._prepare_credentials() + init_result = self.post_server(endpoint="init_server", + server_url=url, + data={**self.to_dict([ParamDataClass]), + **{"model_name": self.model_name}}) + if init_result == "Accepted": + client= OpenAI( + api_key=credentials["api_key"], + base_url=f"{url}/{self.model_name}" + "/v1", + default_headers=self.get_default_headers(), + ) + else: + raise RuntimeError(f"worker_url ({url}/{self.model_name}) initialization failed: {init_result}") - def add_worker(self, url: str, client) -> None: with self.lock: self.workers_state[url] = {"status": "ready", "client": client} @@ -3634,14 +3637,16 @@ def release_worker(self, url: str): self.workers_state[url]["status"] = "ready" def assign_worker(self): - with self.lock: - while True: + while True: + with self.lock: # print("trying to assign worker...") for url, rec in self.workers_state.items(): if rec["status"] == "ready": rec["status"] ="assigned" return url, rec["client"] - time.sleep(random.uniform(0, 1)) + if len(self.workers_state) == 0: + logger.warning("Not servers available for inference. Waiting for servers to be added... ") + time.sleep(10) def _prepare_credentials(self) -> CredentialsOpenAi: return {"api" + "_" + "key": "no-api-key",} @@ -3717,68 +3722,164 @@ def store_after_pack_in_cache(future, cache_key): else: self._cache[cache_key] = prediction + def cleanup(self): + pass + + def _signal_handler(self, signum, frame): + logger.info(f"Received signal {signum}, cleaning up and exiting") + self.cleanup() + sys.exit(0) + + def _exception_handler(self, exc_type, exc_value, exc_traceback): + # Don't double-handle KeyboardInterrupt + if issubclass(exc_type, KeyboardInterrupt): + sys.__excepthook__(exc_type, exc_value, exc_traceback) + return + + logger.error("Uncaught exception, cleaning up before exit", + exc_info=(exc_type, exc_value, exc_traceback)) + self.cleanup() + # Print the exception as usual + sys.__excepthook__(exc_type, exc_value, exc_traceback) + sys.exit(1) + + def _register_cleanup_handlers(self): + # 1) Normal exit + atexit.register(self.cleanup) + + # 2) OS signals + signal.signal(signal.SIGINT, self._signal_handler) # Ctrl+C + signal.signal(signal.SIGTERM, self._signal_handler) # kill + + # 3) Uncaught exceptions + sys.excepthook = self._exception_handler + -class CCCInferenceEngine(MultiServersInferenceEngine): +class CCCInferenceEngine(MultiServersInferenceEngine, PackageRequirementsMixin): ccc_host: str ccc_user: str - ccc_path: str + ccc_temp_dir = "$XDG_CACHE_HOME" ccc_python: str - server_port: str = "5000" + num_of_workers: int = 5 - workers_url: List[str] = [] + ccc_queue: Literal["x86_24h", "nonstandard"] = "x86_24h" + ccc_gpu: Literal["v100", "a100", "a100_80gb"] = "a100_80gb" + ccc_mem: str = "120g" + ccc_num_gpus: int = 1 + + server_port: str = "5000" + + ccc_jobs: Dict[str, Literal["AVAIL", "RUN", "EXIT", "ERROR"]] = {} + + _requirements_list = { + 
"paramiko": "Install paramiko package using 'pip install --upgrade paramiko", + } def prepare_engine(self): assert not self.workers_url, "CCCInferenceEngine doesn't support explicit setting of workers_url" - self.start_ccc_servers() + # the super class prepare_engine() must be executed first, as the following logic relies on its work. self.prepare_engine() + self._connect() + self._submit_jobs() + self._start_monitoring_jobs() - def start_ccc_servers(self): + def _connect(self): import paramiko - ssh = paramiko.SSHClient() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.ccc_host, username=self.ccc_user) - ssh.exec_command(f"mkdir -p {self.ccc_path}") - - self.ccc_jobs = {} - for i in range(self.num_of_workers): - command = f"bash -l -c 'jbsub -queue x86_6h -cores 4+1 -require v100 -mem 24G -out ~/server{i}.log {self.ccc_python} /dccstor/fuse/unitxt/ccc_worker_server.py {self.server_port}'" - stdin, stdout, stderr = ssh.exec_command(command) + self.ssh = paramiko.SSHClient() + self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.ssh.connect(self.ccc_host, username=self.ccc_user) + + def _submit_jobs(self) -> None: + for _ in range(self.num_of_workers): + # TODO: We might need to use accelerate in the command for multi-gpu. See FmEval + command = (f"bash -l -c 'jbsub " + f"-queue {self.ccc_queue} " + f"-cores 4+{self.ccc_num_gpus} " + f"-require {self.ccc_gpu} " + f"-mem {self.ccc_mem} " + f"{self.ccc_python} unitxt-inference-server {self.server_port}'") + + stdin, stdout, stderr = self.ssh.exec_command(command) job_output = stdout.read().decode().strip() job_error = stderr.read().decode().strip() match = re.search(r"Job <(\d+)> is submitted", job_output) if match: job_id = match.group(1) logger.info(f"Start job ID: {job_id}") - self.ccc_jobs[job_id] ={"status": "AVAIL", "log_id": i} + self.ccc_jobs[job_id] = "AVAIL" else: - raise RuntimeError(f"Failed to run jbsub on host {self.ccc_host}.\nstdout: {job_output}.\nstderr: {job_error}") - - def run_monitor_ccc_jobs(ssh, sample_every): - while True: - command = "bash -l -c 'jbinfo'" - stdin, stdout, stderr = ssh.exec_command(command) - output = stdout.read().decode().strip() - #error = stderr.read().decode().strip() - for job_id in self.ccc_jobs.keys(): - match = re.search(rf"^{job_id}\s+\S+\s+(\w+)", output, re.MULTILINE) - if match: - status = match.group(1) - if status != self.ccc_jobs[job_id]["status"]: - if self.ccc_jobs[job_id]["status"] == "RUN": - pass # add server to server list - elif status == "RUN": - pass # remove server from server list. Consider fetching the server log. 
- self.ccc_jobs[job_id]["status"] = status - logger.info(f"status has been changed: {job_id} - {status}") - - - time.sleep(sample_every) - - thread = threading.Thread(target=run_monitor_ccc_jobs, args=(ssh, 10)) - thread.daemon = True # - thread.start() - - - time.sleep(200) # This keeps the main thread alive so the background thread can continue + raise RuntimeError( + f"Failed to run jbsub on host {self.ccc_host}.\nstdout: {job_output}.\nstderr: {job_error}") + + def _start_monitoring_jobs(self): + self.monitor_thread = threading.Thread(target=self._monitor_jobs, daemon=True) + self.monitor_thread.start() + + def _monitor_jobs(self): + while True: + command = "bash -l -c 'jbinfo'" + stdin, stdout, stderr = self.ssh.exec_command(command) + output = stdout.read().decode() + for job_id in self.ccc_jobs.keys(): + match = re.search(rf"^{job_id}\s+\S+\s+(\w+)", output, re.MULTILINE) + if match: + status = match.group(1) + if status != self.ccc_jobs["status"]: + logger.info(f"status has been changed: {job_id} -> {status}") + self.ccc_jobs[job_id]= status + if status == "RUN": + self._add_server_to_list(job_id) + elif status == "EXIT": + pass # remove server from server list. Consider fetching the server log. Maybe rerun it? What happens when the 24h are up? + elif status == "DONE": + pass # remove server from server list. Consider fetching the server log. + time.sleep(60) + + + def _add_server_to_list(self, job_id): + assert self.ccc_jobs[job_id] == "RUN" + try_fetch_server_url_tries = 2 + + while try_fetch_server_url_tries > 0: + stdout, stderr = self._fetch_job_logs(job_id) + + ip_match = re.search(r"server_ip=([0-9.]+)", stdout) + port_match = re.search(r"server_port=(\d+)", stdout) + + if ip_match and port_match: + ip_address = ip_match.group(1) + port = port_match.group(1) + self.add_worker(f"http://{ip_address}:{port}") + else: + time.sleep(30) + try_fetch_server_url_tries =- 1 + if try_fetch_server_url_tries == 0: + raise RuntimeError(f"Error building server on job {job_id}." + f"stdout: {stdout}." 
+ f"stderr:{stderr}") + + def _fetch_job_logs(self, job_id: str) -> Tuple[str, str]: + stdout_fpath, stderr_fpath = self._get_job_log_files_paths(job_id) + sftp = self.ssh.open_sftp() + try: + with sftp.open(stdout_fpath, "r") as f: + stdout = f.read().decode().strip() + with sftp.open(stderr_fpath, "r") as f: + stderr = f.read().decode().strip() + finally: + sftp.close() + + return stdout, stderr + + @staticmethod + def _get_job_log_files_paths(job_id: str) -> Tuple[str, str]: + default_dir = "~/.lsf/dcc/" + return f"{default_dir}/{job_id}.stdout", f"{default_dir}/{job_id}.stderr" + + def cleanup(self): + for job_id in self.ccc_jobs.keys(): + logger.info(f"Killing job {job_id}") + command = f"bash -l -c 'jbadmin -kill {job_id}'" + self.ssh.exec_command(command) From ddc7c7df917552efb1820ddfb7f3e9b732972088 Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Tue, 29 Apr 2025 14:12:36 +0300 Subject: [PATCH 11/28] update --- src/unitxt/inference.py | 83 ++++++++++++++++++++++------------------- test_caching_ccc.py | 6 +-- 2 files changed, 48 insertions(+), 41 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index f12dfad665..21033cdbcd 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -3595,6 +3595,7 @@ def _infer( class MultiServersInferenceEngine(OpenAiInferenceEngine): workers_url: List[str] = [] + num_of_workers: int @staticmethod def post_server(server_url: str, endpoint:str , data: Dict) -> str: @@ -3604,10 +3605,13 @@ def post_server(server_url: str, endpoint:str , data: Dict) -> str: return response.json() def prepare_engine(self): + assert self.num_of_workers > 0 self._register_cleanup_handlers() self.lock = threading.Lock() self.workers_state = {} + self._executor = ThreadPoolExecutor(max_workers=self.num_of_workers) + for url in self.workers_url: self.add_worker(url) @@ -3656,11 +3660,6 @@ def _infer( dataset: Union[List[Dict[str, Any]], Dataset], return_meta_data: bool = False, ) -> List[Any]: # Now returns a Future object - """Runs inference in parallel, returning futures for each batch.""" - # Lazy-initialize executor if not already created - if not hasattr(self, "_executor"): - self._executor = ThreadPoolExecutor(max_workers=len(self.workers_state)) - # Submit the batch job batch_future = self._executor.submit(self._run_batch, dataset, return_meta_data) @@ -3728,7 +3727,7 @@ def cleanup(self): def _signal_handler(self, signum, frame): logger.info(f"Received signal {signum}, cleaning up and exiting") self.cleanup() - sys.exit(0) + os._exit(0) def _exception_handler(self, exc_type, exc_value, exc_traceback): # Don't double-handle KeyboardInterrupt @@ -3741,7 +3740,7 @@ def _exception_handler(self, exc_type, exc_value, exc_traceback): self.cleanup() # Print the exception as usual sys.__excepthook__(exc_type, exc_value, exc_traceback) - sys.exit(1) + os._exit(0) def _register_cleanup_handlers(self): # 1) Normal exit @@ -3755,7 +3754,7 @@ def _register_cleanup_handlers(self): sys.excepthook = self._exception_handler -class CCCInferenceEngine(MultiServersInferenceEngine, PackageRequirementsMixin): +class CCCInferenceEngine(MultiServersInferenceEngine, PackageRequirementsMixin, HFGenerationParamsMixin): ccc_host: str ccc_user: str ccc_temp_dir = "$XDG_CACHE_HOME" @@ -3778,7 +3777,7 @@ class CCCInferenceEngine(MultiServersInferenceEngine, PackageRequirementsMixin): def prepare_engine(self): assert not self.workers_url, "CCCInferenceEngine doesn't support explicit setting of workers_url" # the super class prepare_engine() must be executed 
first, as the following logic relies on its work. - self.prepare_engine() + super().prepare_engine() self._connect() self._submit_jobs() self._start_monitoring_jobs() @@ -3798,7 +3797,7 @@ def _submit_jobs(self) -> None: f"-cores 4+{self.ccc_num_gpus} " f"-require {self.ccc_gpu} " f"-mem {self.ccc_mem} " - f"{self.ccc_python} unitxt-inference-server {self.server_port}'") + f"{self.ccc_python} -m unitxt.service.inference_server {self.server_port}'") stdin, stdout, stderr = self.ssh.exec_command(command) job_output = stdout.read().decode().strip() @@ -3817,24 +3816,31 @@ def _start_monitoring_jobs(self): self.monitor_thread.start() def _monitor_jobs(self): - while True: - command = "bash -l -c 'jbinfo'" - stdin, stdout, stderr = self.ssh.exec_command(command) - output = stdout.read().decode() - for job_id in self.ccc_jobs.keys(): - match = re.search(rf"^{job_id}\s+\S+\s+(\w+)", output, re.MULTILINE) - if match: - status = match.group(1) - if status != self.ccc_jobs["status"]: - logger.info(f"status has been changed: {job_id} -> {status}") - self.ccc_jobs[job_id]= status - if status == "RUN": - self._add_server_to_list(job_id) - elif status == "EXIT": - pass # remove server from server list. Consider fetching the server log. Maybe rerun it? What happens when the 24h are up? - elif status == "DONE": - pass # remove server from server list. Consider fetching the server log. - time.sleep(60) + try: + while True: + command = "bash -l -c 'jbinfo'" + stdin, stdout, stderr = self.ssh.exec_command(command) + output = stdout.read().decode() + for job_id in self.ccc_jobs.keys(): + match = re.search(rf"^{job_id}\s+\S+\s+(\w+)", output, re.MULTILINE) + if match: + status = match.group(1) + if status != self.ccc_jobs[job_id]: + logger.info(f"status has been changed: {job_id} -> {status}") + self.ccc_jobs[job_id]= status + if status == "RUN": + self._add_server_to_list(job_id) + elif status == "EXIT": + time.sleep(10) + stdout, stderr = self._fetch_job_logs(job_id) + logger.error(stdout) + logger.error(stderr) + elif status == "DONE": + pass # remove server from server list. Consider fetching the server log. + time.sleep(60) + except Exception as e: + logger.exception(f"Fatal error in monitor thread, shutting down entire process. 
{e!s}") + os.kill(os.getpid(), signal.SIGTERM) def _add_server_to_list(self, job_id): @@ -3860,26 +3866,27 @@ def _add_server_to_list(self, job_id): f"stderr:{stderr}") def _fetch_job_logs(self, job_id: str) -> Tuple[str, str]: - stdout_fpath, stderr_fpath = self._get_job_log_files_paths(job_id) + stdout_path, stderr_path = self._get_job_log_files_paths(job_id) sftp = self.ssh.open_sftp() try: - with sftp.open(stdout_fpath, "r") as f: + with sftp.open(stdout_path, "r") as f: stdout = f.read().decode().strip() - with sftp.open(stderr_fpath, "r") as f: + with sftp.open(stderr_path, "r") as f: stderr = f.read().decode().strip() finally: sftp.close() return stdout, stderr - @staticmethod - def _get_job_log_files_paths(job_id: str) -> Tuple[str, str]: - default_dir = "~/.lsf/dcc/" + def _get_job_log_files_paths(self, job_id: str) -> Tuple[str, str]: + stdin, stdout, stderr = self.ssh.exec_command("echo $HOME") + remote_home = stdout.read().decode().strip() + default_dir = f"{remote_home}/.lsf/cccCluster" return f"{default_dir}/{job_id}.stdout", f"{default_dir}/{job_id}.stderr" def cleanup(self): - for job_id in self.ccc_jobs.keys(): - logger.info(f"Killing job {job_id}") - command = f"bash -l -c 'jbadmin -kill {job_id}'" - self.ssh.exec_command(command) + logger.info(f"Killing job {self.ccc_jobs.keys()}") + command = f"bash -l -c 'jbadmin -kill {' '.join(self.ccc_jobs.keys())}'" + logger.info(command) + self.ssh.exec_command(command) diff --git a/test_caching_ccc.py b/test_caching_ccc.py index c77710b676..6d90f2e226 100644 --- a/test_caching_ccc.py +++ b/test_caching_ccc.py @@ -49,10 +49,10 @@ def load_dataset_cached(**kwargs): use_cache=True, cache_batch_size=5, ccc_host="cccxl013.pok.ibm.com", - ccc_user="eladv", - ccc_path="", - ccc_python="/dccstor/fuse/eladv_envs/unitxt/bin/python", + ccc_user="ofirarviv", + ccc_python="/dccstor/fme/users/ofir.arviv/miniforge3/envs/fme/bin/python", num_of_workers=3, + ccc_queue = "nonstandard" ) start_time = time.time() From 18ea4594e2a8eaaf5c365d45de2c90cb5c85a85c Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Tue, 29 Apr 2025 14:13:33 +0300 Subject: [PATCH 12/28] update --- src/unitxt/service/inference_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/unitxt/service/inference_server.py b/src/unitxt/service/inference_server.py index 62734248ad..a392806ef3 100644 --- a/src/unitxt/service/inference_server.py +++ b/src/unitxt/service/inference_server.py @@ -128,6 +128,7 @@ def status(): parser = argparse.ArgumentParser(prog="unitxt inference worker server") parser.add_argument("port", type=int, help="Port to run the server on", default=8080) args = parser.parse_args() + PORT = args.port app.run(host="0.0.0.0", port=args.port, debug=True) From fd0bd00aa346acfd64db27d4b22dfbf1e5965773 Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Tue, 29 Apr 2025 14:15:20 +0300 Subject: [PATCH 13/28] update --- src/unitxt/service/inference_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/unitxt/service/inference_server.py b/src/unitxt/service/inference_server.py index a392806ef3..7a75503898 100644 --- a/src/unitxt/service/inference_server.py +++ b/src/unitxt/service/inference_server.py @@ -126,7 +126,7 @@ def status(): if __name__ == "__main__": parser = argparse.ArgumentParser(prog="unitxt inference worker server") - parser.add_argument("port", type=int, help="Port to run the server on", default=8080) + parser.add_argument("port", type=int, help="Port to run the server on", default=8080, required=False) args = parser.parse_args() 
PORT = args.port app.run(host="0.0.0.0", port=args.port, debug=True) From aa7136912b37d14fd01f3939ece740cfcddc56ba Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Tue, 29 Apr 2025 14:15:35 +0300 Subject: [PATCH 14/28] update --- src/unitxt/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 21033cdbcd..d0858a72fb 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -3797,7 +3797,7 @@ def _submit_jobs(self) -> None: f"-cores 4+{self.ccc_num_gpus} " f"-require {self.ccc_gpu} " f"-mem {self.ccc_mem} " - f"{self.ccc_python} -m unitxt.service.inference_server {self.server_port}'") + f"{self.ccc_python} -m unitxt.service.inference_server --port {self.server_port}'") stdin, stdout, stderr = self.ssh.exec_command(command) job_output = stdout.read().decode().strip() From 9e4bb86cd4ee4aaefbc2e8ce4fd2f3f320317c63 Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Tue, 29 Apr 2025 14:15:55 +0300 Subject: [PATCH 15/28] update --- src/unitxt/service/inference_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/unitxt/service/inference_server.py b/src/unitxt/service/inference_server.py index 7a75503898..788eede1e6 100644 --- a/src/unitxt/service/inference_server.py +++ b/src/unitxt/service/inference_server.py @@ -126,7 +126,7 @@ def status(): if __name__ == "__main__": parser = argparse.ArgumentParser(prog="unitxt inference worker server") - parser.add_argument("port", type=int, help="Port to run the server on", default=8080, required=False) + parser.add_argument("--port", type=int, help="Port to run the server on", default=8080, required=False) args = parser.parse_args() PORT = args.port app.run(host="0.0.0.0", port=args.port, debug=True) From 47aaf85bd523f336c76e1eb8316c126e7c2b602f Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Tue, 29 Apr 2025 14:31:43 +0300 Subject: [PATCH 16/28] update --- src/unitxt/inference.py | 21 +++++++++++++-------- src/unitxt/service/inference_server.py | 8 ++++---- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index d0858a72fb..1702635cdb 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -3630,6 +3630,7 @@ def add_worker(self, url: str) -> None: base_url=f"{url}/{self.model_name}" + "/v1", default_headers=self.get_default_headers(), ) + logger.info(f"Adding server {url}") else: raise RuntimeError(f"worker_url ({url}/{self.model_name}) initialization failed: {init_result}") @@ -3766,7 +3767,7 @@ class CCCInferenceEngine(MultiServersInferenceEngine, PackageRequirementsMixin, ccc_mem: str = "120g" ccc_num_gpus: int = 1 - server_port: str = "5000" + server_port: str = "8080" ccc_jobs: Dict[str, Literal["AVAIL", "RUN", "EXIT", "ERROR"]] = {} @@ -3850,20 +3851,24 @@ def _add_server_to_list(self, job_id): while try_fetch_server_url_tries > 0: stdout, stderr = self._fetch_job_logs(job_id) - ip_match = re.search(r"server_ip=([0-9.]+)", stdout) - port_match = re.search(r"server_port=(\d+)", stdout) + ip_match = re.search(r"server_ip=([0-9.]+)", stderr) + port_match = re.search( + r"server_port=(\d+)", stderr) if ip_match and port_match: ip_address = ip_match.group(1) port = port_match.group(1) self.add_worker(f"http://{ip_address}:{port}") else: - time.sleep(30) - try_fetch_server_url_tries =- 1 + try_fetch_server_url_tries -= 1 + logger.warning(f"Did not manage to get server url. 
Trying {try_fetch_server_url_tries} more times...") if try_fetch_server_url_tries == 0: - raise RuntimeError(f"Error building server on job {job_id}." - f"stdout: {stdout}." - f"stderr:{stderr}") + logger.exception(f"Error building server on job {job_id}.\n" + f"stdout: {stdout}.\n" + f"stderr:{stderr}") + os.kill(os.getpid(), signal.SIGTERM) + time.sleep(30) + def _fetch_job_logs(self, job_id: str) -> Tuple[str, str]: stdout_path, stderr_path = self._get_job_log_files_paths(job_id) diff --git a/src/unitxt/service/inference_server.py b/src/unitxt/service/inference_server.py index 788eede1e6..e2b467dae6 100644 --- a/src/unitxt/service/inference_server.py +++ b/src/unitxt/service/inference_server.py @@ -25,10 +25,6 @@ def __init__(self): self.shutdown_flag = False self.monitor_thread.start() - hostname = socket.gethostname() - ip_address = socket.gethostbyname(hostname) - app.logger.info(f"server_ip={ip_address} server_port={PORT}") - def update_last_request_time(self): self.last_request_time = time.time() @@ -129,6 +125,10 @@ def status(): parser.add_argument("--port", type=int, help="Port to run the server on", default=8080, required=False) args = parser.parse_args() PORT = args.port + + hostname = socket.gethostname() + ip_address = socket.gethostbyname(hostname) + logging.info(f"server_ip={ip_address} server_port={PORT}") app.run(host="0.0.0.0", port=args.port, debug=True) From 17dd2e76e09f9dcb2cef2c9772b061d1ff186150 Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Tue, 29 Apr 2025 14:40:01 +0300 Subject: [PATCH 17/28] update --- src/unitxt/inference.py | 4 ++-- src/unitxt/service/inference_server.py | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 1702635cdb..c705882d4e 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -3592,7 +3592,7 @@ def _infer( ParamDataClass = TypeVar("ParamDataClass") -class MultiServersInferenceEngine(OpenAiInferenceEngine): +class MultiServersInferenceEngine(OpenAiInferenceEngine, HFGenerationParamsMixin): workers_url: List[str] = [] num_of_workers: int @@ -3622,7 +3622,7 @@ def add_worker(self, url: str) -> None: credentials = self._prepare_credentials() init_result = self.post_server(endpoint="init_server", server_url=url, - data={**self.to_dict([ParamDataClass]), + data={**self.to_dict([HFGenerationParamsMixin]), **{"model_name": self.model_name}}) if init_result == "Accepted": client= OpenAI( diff --git a/src/unitxt/service/inference_server.py b/src/unitxt/service/inference_server.py index e2b467dae6..16b0bfe3e0 100644 --- a/src/unitxt/service/inference_server.py +++ b/src/unitxt/service/inference_server.py @@ -25,6 +25,10 @@ def __init__(self): self.shutdown_flag = False self.monitor_thread.start() + hostname = socket.gethostname() + ip_address = socket.gethostbyname(hostname) + logging.info(f"server_ip={ip_address} server_port={PORT}") + def update_last_request_time(self): self.last_request_time = time.time() @@ -124,11 +128,6 @@ def status(): parser = argparse.ArgumentParser(prog="unitxt inference worker server") parser.add_argument("--port", type=int, help="Port to run the server on", default=8080, required=False) args = parser.parse_args() - PORT = args.port - - hostname = socket.gethostname() - ip_address = socket.gethostbyname(hostname) - logging.info(f"server_ip={ip_address} server_port={PORT}") app.run(host="0.0.0.0", port=args.port, debug=True) From fba027487e8273732a9eed61bb0305c447904817 Mon Sep 17 00:00:00 2001 From: ofirarviv Date: Tue, 29 Apr 
From fba027487e8273732a9eed61bb0305c447904817 Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Tue, 29 Apr 2025 14:48:45 +0300
Subject: [PATCH 18/28] update

---
 src/unitxt/service/inference_server.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/unitxt/service/inference_server.py b/src/unitxt/service/inference_server.py
index 16b0bfe3e0..a8e1c289c3 100644
--- a/src/unitxt/service/inference_server.py
+++ b/src/unitxt/service/inference_server.py
@@ -14,20 +14,20 @@
 logging.basicConfig(level=logging.INFO)
 
 app = Flask(__name__)
-PORT = None
 
 
 class Server:
-    def __init__(self):
+    def __init__(self, port: int):
         self.inference_engine = None
         self.inactivity_timeout = 600
         self.monitor_thread = threading.Thread(target=self.monitor_activity, daemon=True)
         self.last_request_time = time.time()
         self.shutdown_flag = False
         self.monitor_thread.start()
+        self.port = port
 
         hostname = socket.gethostname()
         ip_address = socket.gethostbyname(hostname)
-        logging.info(f"server_ip={ip_address} server_port={PORT}")
+        logging.info(f"server_ip={ip_address} server_port={self.port}")
 
     def update_last_request_time(self):
         self.last_request_time = time.time()
@@ -38,7 +38,7 @@ def monitor_activity(self):
         if time.time() - self.last_request_time > self.inactivity_timeout:
             app.logger.info(f"No requests for {self.inactivity_timeout} seconds. Shutting down server...")
             try:
-                requests.post(f"http://localhost:{PORT}/shutdown", timeout=5)
+                requests.post(f"http://localhost:{self.port}/shutdown", timeout=5)
             except Exception:
                 pass
         else:
@@ -65,8 +65,6 @@ def infer(self, **kwargs):
         return self.inference_engine(inputs)
 
 
-server = Server()
-
 @app.before_request
 def update_activity():
     server.update_last_request_time()
@@ -128,6 +126,7 @@ def status():
     parser = argparse.ArgumentParser(prog="unitxt inference worker server")
     parser.add_argument("--port", type=int, help="Port to run the server on", default=8080, required=False)
     args = parser.parse_args()
+    server = Server(args.port)
     app.run(host="0.0.0.0", port=args.port, debug=True)

From c0ef23c93fde67e97b915fd74ad0ed64877bceaf Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Tue, 29 Apr 2025 14:58:07 +0300
Subject: [PATCH 19/28] update

---
 src/unitxt/service/inference_server.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/unitxt/service/inference_server.py b/src/unitxt/service/inference_server.py
index a8e1c289c3..9d326d6f78 100644
--- a/src/unitxt/service/inference_server.py
+++ b/src/unitxt/service/inference_server.py
@@ -8,6 +8,7 @@
 
 import requests
 from flask import Flask, jsonify, request
+from werkzeug.serving import make_server
 
 from ..inference import HFPipelineBasedInferenceEngine
 
@@ -127,7 +128,13 @@ def status():
     parser.add_argument("--port", type=int, help="Port to run the server on", default=8080, required=False)
     args = parser.parse_args()
     server = Server(args.port)
-    app.run(host="0.0.0.0", port=args.port, debug=True)
-
+    srv = make_server("0.0.0.0", args.port, app)
+    # only here after bind succeeded
+    hostname = socket.gethostname()
+    ip_address = socket.gethostbyname(hostname)
+    logging.info(f"server_ip={ip_address} server_port={args.port}")
+
+    # this actually starts the Werkzeug loop (blocking)
+    srv.serve_forever()
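
The switch from app.run(...) to make_server(...) matters because make_server binds the socket before it returns, while app.run blocks forever, leaving no point at which to log a guaranteed-live address. A sketch of the pattern with the error path made explicit; the sys.exit handling is an assumption, not in the patch.

    import logging
    import socket
    import sys

    from flask import Flask
    from werkzeug.serving import make_server

    app = Flask(__name__)
    port = 8080

    try:
        # binds here; raises OSError immediately if the port is taken
        srv = make_server("0.0.0.0", port, app)
    except OSError as e:
        logging.error(f"could not bind port {port}: {e}")
        sys.exit(1)

    # safe to advertise the address: the bind already succeeded
    logging.info(f"server_ip={socket.gethostbyname(socket.gethostname())} server_port={port}")
    srv.serve_forever()  # blocking request loop
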
From fd3ab22e933a86fb597e17a1ccaa6f1fc75cd226 Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Tue, 29 Apr 2025 14:59:40 +0300
Subject: [PATCH 20/28] update

---
 src/unitxt/service/inference_server.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/unitxt/service/inference_server.py b/src/unitxt/service/inference_server.py
index 9d326d6f78..1c0df6fede 100644
--- a/src/unitxt/service/inference_server.py
+++ b/src/unitxt/service/inference_server.py
@@ -26,10 +26,6 @@ def __init__(self, port: int):
         self.monitor_thread.start()
         self.port = port
 
-        hostname = socket.gethostname()
-        ip_address = socket.gethostbyname(hostname)
-        logging.info(f"server_ip={ip_address} server_port={self.port}")
-
     def update_last_request_time(self):
         self.last_request_time = time.time()

From 8de97d4d9e45605130a08aab944b88b48a329823 Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Tue, 29 Apr 2025 15:08:04 +0300
Subject: [PATCH 21/28] update

---
 src/unitxt/service/inference_server.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/unitxt/service/inference_server.py b/src/unitxt/service/inference_server.py
index 1c0df6fede..d76f04294c 100644
--- a/src/unitxt/service/inference_server.py
+++ b/src/unitxt/service/inference_server.py
@@ -124,7 +124,7 @@ def status():
     parser.add_argument("--port", type=int, help="Port to run the server on", default=8080, required=False)
     args = parser.parse_args()
     server = Server(args.port)
-    srv = make_server("0.0.0.0", args.port, app)
+    srv = make_server("0.0.0.0", args.port, app, threaded=True)
     # only here after bind succeeded
     hostname = socket.gethostname()
     ip_address = socket.gethostbyname(hostname)
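
threaded=True is likely load-bearing here rather than cosmetic: the monitor thread POSTs to the server's own /shutdown endpoint from inside the same process, and a single-threaded Werkzeug server that is busy with a long inference call could never serve that request. This reading is inferred from the surrounding code, not stated in the commit.

    # One handler thread per request, so /shutdown (or /status probes) can be
    # served while a long-running /v1/chat/completions call is still in flight.
    srv = make_server("0.0.0.0", args.port, app, threaded=True)
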
From 0d7bd46084a38dd2ac1b8e41ae54f203b5984532 Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Tue, 29 Apr 2025 15:16:41 +0300
Subject: [PATCH 22/28] update

---
 src/unitxt/inference.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index c705882d4e..197e7cc9b2 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -39,6 +39,7 @@
 from tqdm import tqdm, trange
 from tqdm.asyncio import tqdm_asyncio
 
+from . import dataclass
 from .artifact import Artifact
 from .dataclass import InternalField, NonPositionalField
 from .deprecation_utils import deprecation
@@ -3641,6 +3642,10 @@ def release_worker(self, url: str):
         with self.lock:
             self.workers_state[url]["status"] = "ready"
 
+    def remove_worker(self, url: str):
+        with self.lock:
+            del self.workers_state[url]
+
     def assign_worker(self):
         while True:
             with self.lock:
@@ -3754,6 +3759,10 @@ def _register_cleanup_handlers(self):
         # 3) Uncaught exceptions
         sys.excepthook = self._exception_handler
 
+@dataclass
+class CCCServerWorkerInfo:
+    status: Literal["AVAIL", "RUN", "EXIT", "ERROR"] = "AVAIL"
+    server_url: Optional[str] = None
 
 class CCCInferenceEngine(MultiServersInferenceEngine, PackageRequirementsMixin, HFGenerationParamsMixin):
     ccc_host: str
@@ -3767,9 +3776,9 @@ class CCCInferenceEngine(MultiServersInferenceEngine, PackageRequirementsMixin,
     ccc_mem: str = "120g"
     ccc_num_gpus: int = 1
 
-    server_port: str = "8080"
+    server_port: str = "5000"
 
-    ccc_jobs: Dict[str, Literal["AVAIL", "RUN", "EXIT", "ERROR"]] = {}
+    ccc_jobs: Dict[str, CCCServerWorkerInfo] = {}
 
     _requirements_list = {
         "paramiko": "Install paramiko package using 'pip install --upgrade paramiko",
@@ -3807,7 +3816,7 @@ def _submit_jobs(self) -> None:
             if match:
                 job_id = match.group(1)
                 logger.info(f"Start job ID: {job_id}")
-                self.ccc_jobs[job_id] = "AVAIL"
+                self.ccc_jobs[job_id] = CCCServerWorkerInfo()
             else:
                 raise RuntimeError(
                     f"Failed to run jbsub on host {self.ccc_host}.\nstdout: {job_output}.\nstderr: {job_error}")
@@ -3826,9 +3835,9 @@ def _monitor_jobs(self):
                     match = re.search(rf"^{job_id}\s+\S+\s+(\w+)", output, re.MULTILINE)
                     if match:
                         status = match.group(1)
-                        if status != self.ccc_jobs[job_id]:
+                        if status != self.ccc_jobs[job_id].status:
                             logger.info(f"status has been changed: {job_id} -> {status}")
-                            self.ccc_jobs[job_id] = status
+                            self.ccc_jobs[job_id].status = status
                             if status == "RUN":
                                 self._add_server_to_list(job_id)
                             elif status == "EXIT":
@@ -3836,8 +3845,9 @@ def _monitor_jobs(self):
                                 stdout, stderr = self._fetch_job_logs(job_id)
                                 logger.error(stdout)
                                 logger.error(stderr)
+                                self.remove_worker(self.ccc_jobs[job_id].server_url)
                             elif status == "DONE":
-                                pass  # remove server from server list. Consider fetching the server log.
+                                self.remove_worker(self.ccc_jobs[job_id].server_url)
                 time.sleep(60)
         except Exception as e:
             logger.exception(f"Fatal error in monitor thread, shutting down entire process. {e!s}")
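
The move from Dict[str, Literal[...]] to Dict[str, CCCServerWorkerInfo] lets the monitor track a job's server URL alongside its status, which remove_worker needs once a job exits. Reduced to a self-contained sketch; names are illustrative rather than the exact unitxt implementation, which also tracks per-worker clients.

    import threading
    from dataclasses import dataclass
    from typing import Dict, Optional

    @dataclass
    class WorkerInfo:
        status: str = "AVAIL"          # AVAIL -> RUN -> EXIT / ERROR / DONE
        server_url: Optional[str] = None

    class WorkerRegistry:
        def __init__(self) -> None:
            self.lock = threading.Lock()
            self.workers: Dict[str, WorkerInfo] = {}

        def remove(self, url: str) -> None:
            # called from the monitor thread when a job reaches EXIT or DONE,
            # so the dispatcher stops routing requests to a dead server
            with self.lock:
                self.workers.pop(url, None)
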
From 07104955a471cc40ad25916c166e748069ee85ed Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Tue, 29 Apr 2025 15:27:29 +0300
Subject: [PATCH 23/28] update

---
 src/unitxt/inference.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index 197e7cc9b2..6b26d28c5b 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -39,7 +39,6 @@
 from tqdm import tqdm, trange
 from tqdm.asyncio import tqdm_asyncio
 
-from . import dataclass
 from .artifact import Artifact
 from .dataclass import InternalField, NonPositionalField
 from .deprecation_utils import deprecation
@@ -3759,8 +3758,7 @@ def _register_cleanup_handlers(self):
         # 3) Uncaught exceptions
         sys.excepthook = self._exception_handler
 
-@dataclass
-class CCCServerWorkerInfo:
+class CCCServerWorkerInfo(Artifact):
     status: Literal["AVAIL", "RUN", "EXIT", "ERROR"] = "AVAIL"
     server_url: Optional[str] = None

From 7446daebf32aebae424c37ec396fd5a6fc1396e4 Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Tue, 29 Apr 2025 15:32:32 +0300
Subject: [PATCH 24/28] update

---
 src/unitxt/inference.py | 4 +++-
 test_caching_ccc.py     | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index 6b26d28c5b..447ed13f87 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -3853,7 +3853,7 @@ def _monitor_jobs(self):
 
     def _add_server_to_list(self, job_id):
-        assert self.ccc_jobs[job_id] == "RUN"
+        assert self.ccc_jobs[job_id].status == "RUN"
 
         try_fetch_server_url_tries = 2
         while try_fetch_server_url_tries > 0:
@@ -3903,3 +3903,5 @@ def cleanup(self):
         logger.info(command)
 
         self.ssh.exec_command(command)
+
+    # TODO: Error with the cleanup in ssh when all the inputs are in cache
diff --git a/test_caching_ccc.py b/test_caching_ccc.py
index 6d90f2e226..d3dfcff7b6 100644
--- a/test_caching_ccc.py
+++ b/test_caching_ccc.py
@@ -40,7 +40,7 @@ def load_dataset_cached(**kwargs):
 
     dataset = load_dataset_cached(card="cards.openbook_qa", split="test")
-    dataset = dataset.select(range(100))
+    dataset = dataset.select(range(200))
 
     inference_model = CCCInferenceEngine(
         model_name="google/flan-t5-small",
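
The corrected assert compares the record's .status field rather than the record itself to a string, which would always have been False. For context, _add_server_to_list (whose retry tail opens this series of hunks, back in PATCH 16) amounts to polling the job's stdout for the advertised address a bounded number of times. A sketch under the assumption that workers log the server_ip=... server_port=... line shown earlier; read_job_stdout is a hypothetical callable standing in for the log-fetching step.

    import re
    import time
    from typing import Callable, Optional

    def fetch_server_url(read_job_stdout: Callable[[], str],
                         tries: int = 2, delay: float = 30.0) -> Optional[str]:
        while tries > 0:
            match = re.search(r"server_ip=(\S+) server_port=(\d+)", read_job_stdout())
            if match:
                return f"http://{match.group(1)}:{match.group(2)}"
            tries -= 1
            time.sleep(delay)  # give the job time to finish loading the model
        return None  # caller decides whether to log the failure and kill the process
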
From ffb7726bdc4e91b6aaf46d015bf2f44b4bba667a Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Tue, 29 Apr 2025 15:48:25 +0300
Subject: [PATCH 25/28] update

---
 src/unitxt/inference.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index 447ed13f87..eae5d14cac 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -3898,10 +3898,15 @@ def _get_job_log_files_paths(self, job_id: str) -> Tuple[str, str]:
         return f"{default_dir}/{job_id}.stdout", f"{default_dir}/{job_id}.stderr"
 
     def cleanup(self):
+        transport = self.ssh.get_transport() if self.ssh else None
+        if transport is None or not transport.is_active():
+            # re-open connection
+            self._connect()
         logger.info(f"Killing job {self.ccc_jobs.keys()}")
         command = f"bash -l -c 'jbadmin -kill {' '.join(self.ccc_jobs.keys())}'"
         logger.info(command)
 
         self.ssh.exec_command(command)
 
-    # TODO: Error with the cleanup in ssh when all the inputs are in cache
+    # TODO: Error with the cleanup in ssh when all the inputs are in cache.
+    #  It has to do with the ssh not managing to finish its operation or something like this before cleanup

From dac1ada8d50240940dbd7c29a2484dd33081d134 Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Wed, 30 Apr 2025 12:44:40 +0300
Subject: [PATCH 26/28] update

---
 src/unitxt/inference.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index eae5d14cac..5c81f82fdf 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -3777,6 +3777,7 @@ class CCCInferenceEngine(MultiServersInferenceEngine, PackageRequirementsMixin,
     server_port: str = "5000"
 
     ccc_jobs: Dict[str, CCCServerWorkerInfo] = {}
+    _monitor_jobs: bool = False
 
     _requirements_list = {
         "paramiko": "Install paramiko package using 'pip install --upgrade paramiko",
@@ -3821,11 +3822,12 @@ def _submit_jobs(self) -> None:
 
     def _start_monitoring_jobs(self):
         self.monitor_thread = threading.Thread(target=self._monitor_jobs, daemon=True)
+        self._monitor_jobs = True
         self.monitor_thread.start()
 
     def _monitor_jobs(self):
         try:
-            while True:
+            while self._monitor_jobs:
                 command = "bash -l -c 'jbinfo'"
                 stdin, stdout, stderr = self.ssh.exec_command(command)
                 output = stdout.read().decode()
@@ -3898,6 +3900,12 @@ def _get_job_log_files_paths(self, job_id: str) -> Tuple[str, str]:
         return f"{default_dir}/{job_id}.stdout", f"{default_dir}/{job_id}.stderr"
 
     def cleanup(self):
+        self._monitor_jobs = False
+        """
+        TODO: Error with the cleanup in ssh when all the inputs are in cache.
+        It has to do with the ssh not managing to finish its operation or something like this before cleanup
+        The real solution is to do the cleanup after the main infer()
+        """
         transport = self.ssh.get_transport() if self.ssh else None
         if transport is None or not transport.is_active():
             # re-open connection
             self._connect()
@@ -3906,7 +3914,7 @@ def cleanup(self):
         command = f"bash -l -c 'jbadmin -kill {' '.join(self.ccc_jobs.keys())}'"
         logger.info(command)
 
         self.ssh.exec_command(command)
+        self.ssh.close()
+
 
-    # TODO: Error with the cleanup in ssh when all the inputs are in cache.
-    #  It has to do with the ssh not managing to finish its operation or something like this before cleanup
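
The transport check added to cleanup() guards against a dead connection: paramiko's get_transport() / is_active() pair is the standard liveness probe, and reconnecting before jbadmin -kill means cleanup still works after long idle periods, for example when every input was served from cache and the connection was never exercised. The same guard, extracted as a reusable helper:

    import paramiko

    def ensure_connected(ssh: paramiko.SSHClient, host: str, user: str) -> paramiko.SSHClient:
        # reconnect only if the client was never opened or its transport died
        transport = ssh.get_transport() if ssh else None
        if transport is None or not transport.is_active():
            ssh = paramiko.SSHClient()
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            ssh.connect(host, username=user)
        return ssh
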
From 6b5383f49beb454e94acdd53d35d687ce4e2af75 Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Wed, 30 Apr 2025 12:46:00 +0300
Subject: [PATCH 27/28] update

---
 src/unitxt/inference.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index 5c81f82fdf..817d98ba1d 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -3851,6 +3851,7 @@ def _monitor_jobs(self):
                 time.sleep(60)
         except Exception as e:
             logger.exception(f"Fatal error in monitor thread, shutting down entire process. {e!s}")
+            # is there a better way to do it?
             os.kill(os.getpid(), signal.SIGTERM)

From be5c327af69e35a1d44ff37990889accb8ee7f18 Mon Sep 17 00:00:00 2001
From: ofirarviv
Date: Mon, 5 May 2025 13:55:16 +0300
Subject: [PATCH 28/28] update

---
 src/unitxt/inference.py | 36 +++++++++++++++++++++++++++---------
 1 file changed, 27 insertions(+), 9 deletions(-)

diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index 817d98ba1d..93a711105f 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -3772,12 +3772,12 @@ class CCCInferenceEngine(MultiServersInferenceEngine, PackageRequirementsMixin,
     ccc_queue: Literal["x86_24h", "nonstandard"] = "x86_24h"
     ccc_gpu: Literal["v100", "a100", "a100_80gb"] = "a100_80gb"
     ccc_mem: str = "120g"
-    ccc_num_gpus: int = 1
+    ccc_num_gpus: int = 2
 
     server_port: str = "5000"
 
     ccc_jobs: Dict[str, CCCServerWorkerInfo] = {}
-    _monitor_jobs: bool = False
+    _start_monitor_jobs: bool = False
 
     _requirements_list = {
         "paramiko": "Install paramiko package using 'pip install --upgrade paramiko",
@@ -3798,6 +3798,21 @@ def _connect(self):
         self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
         self.ssh.connect(self.ccc_host, username=self.ccc_user)
 
+    def _send_cmd_via_ssh(self, command: str):
+        """Error with the cleanup in ssh when all the inputs are in cache.
+
+        It has to do with the ssh not managing to finish its operation or something like this before cleanup.
+        The real solution is to do the cleanup after the main infer().
+        """
+        transport = self.ssh.get_transport() if self.ssh else None
+        if transport is None or not transport.is_active():
+            # re-open connection
+            self._connect()
+
+        stdin, stdout, stderr = self.ssh.exec_command(command)
+        return stdin, stdout, stderr
+
+
     def _submit_jobs(self) -> None:
         for _ in range(self.num_of_workers):
             # TODO: We might need to use accelerate in the command for multi-gpu. See FmEval
             command = (f"bash -l -c 'jbsub -q {self.ccc_queue} "
                        f"-mem {self.ccc_mem} "
                        f"{self.ccc_python} -m unitxt.service.inference_server --port {self.server_port}'")
 
-            stdin, stdout, stderr = self.ssh.exec_command(command)
+            stdin, stdout, stderr = self._send_cmd_via_ssh(command)
             job_output = stdout.read().decode().strip()
             job_error = stderr.read().decode().strip()
             match = re.search(r"Job <(\d+)> is submitted", job_output)
@@ -3822,14 +3837,14 @@ def _submit_jobs(self) -> None:
 
     def _start_monitoring_jobs(self):
         self.monitor_thread = threading.Thread(target=self._monitor_jobs, daemon=True)
-        self._monitor_jobs = True
+        self._start_monitor_jobs = True
         self.monitor_thread.start()
 
     def _monitor_jobs(self):
         try:
-            while self._monitor_jobs:
+            while self._start_monitor_jobs:
                 command = "bash -l -c 'jbinfo'"
-                stdin, stdout, stderr = self.ssh.exec_command(command)
+                stdin, stdout, stderr = self._send_cmd_via_ssh(command)
                 output = stdout.read().decode()
                 for job_id in self.ccc_jobs.keys():
                     match = re.search(rf"^{job_id}\s+\S+\s+(\w+)", output, re.MULTILINE)
@@ -3895,13 +3910,13 @@ def _fetch_job_logs(self, job_id: str) -> Tuple[str, str]:
         return stdout, stderr
 
     def _get_job_log_files_paths(self, job_id: str) -> Tuple[str, str]:
-        stdin, stdout, stderr = self.ssh.exec_command("echo $HOME")
+        stdin, stdout, stderr = self._send_cmd_via_ssh("echo $HOME")
         remote_home = stdout.read().decode().strip()
         default_dir = f"{remote_home}/.lsf/cccCluster"
         return f"{default_dir}/{job_id}.stdout", f"{default_dir}/{job_id}.stderr"
 
     def cleanup(self):
-        self._monitor_jobs = False
+        self._start_monitor_jobs = False
         """
         TODO: Error with the cleanup in ssh when all the inputs are in cache.
         It has to do with the ssh not managing to finish its operation or something like this before cleanup
@@ -3914,8 +3929,11 @@ def cleanup(self):
         logger.info(f"Killing job {self.ccc_jobs.keys()}")
         command = f"bash -l -c 'jbadmin -kill {' '.join(self.ccc_jobs.keys())}'"
         logger.info(command)
-        self.ssh.exec_command(command)
+        self._send_cmd_via_ssh(command)
         self.ssh.close()
+
+# TODO: Add check that both unitxt version / commit are the same
+# TODO: Add more info logging
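
One way the first remaining TODO could be addressed (a sketch, not part of the series): have the engine send its own unitxt version in the /init_server payload and let the worker refuse a mismatch, since a client and worker built from different commits can silently disagree on prompt formatting or cache keys.

    from importlib.metadata import version

    def check_unitxt_version(client_version: str) -> None:
        # compare the client's reported version against the worker's install
        worker_version = version("unitxt")
        if client_version != worker_version:
            raise RuntimeError(
                f"unitxt version mismatch: client={client_version}, worker={worker_version}"
            )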