diff --git a/pyproject.toml b/pyproject.toml
index 06dc369582..5ed8201a2f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -164,6 +164,7 @@ unitxt-explore = "unitxt.ui:launch"
 unitxt-assistant = "unitxt.assistant:launch"
 unitxt-metrics-service = "unitxt.service.metrics.main:start_metrics_http_service"
 unitxt-evaluate = "unitxt.evaluate_cli:main"
+unitxt-inference-server = "unitxt.service.inference_server:main"
 
 [tool.ruff]
 exclude = [
diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index fae694f708..93a711105f 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -1,5 +1,6 @@
 import abc
 import asyncio
+import atexit
 import base64
 import dataclasses
 import hashlib
@@ -8,10 +9,13 @@
 import logging
 import os
 import re
+import signal
 import sys
+import threading
 import time
 import uuid
 from collections import Counter
+from concurrent.futures import Future, ThreadPoolExecutor, wait
 from datetime import datetime
 from itertools import islice
 from multiprocessing.pool import ThreadPool
@@ -30,6 +34,7 @@
     Union,
 )
 
+import requests
 from datasets import Dataset, DatasetDict, Image
 from tqdm import tqdm, trange
 from tqdm.asyncio import tqdm_asyncio
@@ -276,7 +281,7 @@ def infer(
                     if prediction is None:
                         continue
                     cache_key = self._get_cache_key(item)
-                    self._cache[cache_key] = prediction
+                    self.store_in_cache(cache_key, prediction)
             else:
                 inferred_results = []
             # Combine cached and inferred results in original order
@@ -286,6 +291,9 @@ def infer(
                 result.extend(batch_predictions)
         else:
             result = self._infer(dataset, return_meta_data)
+
+        result = self.post_process_results(result)
+
         return ListWithMetadata(
             result,
             metadata={
@@ -295,6 +303,12 @@ def infer(
             },
         )
 
+    def store_in_cache(self, cache_key, prediction):
+        self._cache[cache_key] = prediction
+
+    def post_process_results(self, result):
+        return result
+
     def _mock_infer(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
@@ -1957,7 +1971,7 @@ def prepare_engine(self):
     @staticmethod
     def get_base_url_from_model_name(model_name: str):
         base_url_template = (
-            "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/{}"
+            "http://localhost:5000/{}"
         )
         return base_url_template.format(
             RITSInferenceEngine._get_model_name_for_endpoint(model_name)
         )
@@ -3546,10 +3560,9 @@ def _infer(
         dataset: Union[List[Dict[str, Any]], Dataset],
         return_meta_data: bool = False,
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-        if return_meta_data and not hasattr(self.engine, "get_return_object"):
+        if return_meta_data:
             raise NotImplementedError(
-                f"Inference engine {self.engine.__class__.__name__} does not support return_meta_data as it "
-                f"does not contain a 'get_return_object' method. Please set return_meta_data=False."
+                f"Inference engine {self.engine.__class__.__name__} does not support return_meta_data."
             )
 
         inputs = []
@@ -3576,3 +3589,351 @@ def _infer(
         predictions.append(options_scores.most_common(1)[0][0])
 
         return predictions
+
+
+ParamDataClass = TypeVar("ParamDataClass")
+
+
+class MultiServersInferenceEngine(OpenAiInferenceEngine, HFGenerationParamsMixin):
+    workers_url: List[str] = []
+    num_of_workers: int
+
+    @staticmethod
+    def post_server(server_url: str, endpoint: str, data: Dict) -> str:
+        headers = {"Content-Type": "application/json"}
+        response = requests.post(url=f"{server_url}/{endpoint}", json=data, headers=headers)
+        response.raise_for_status()
+        return response.json()
+
+    def prepare_engine(self):
+        assert self.num_of_workers > 0
+        self._register_cleanup_handlers()
+        self.lock = threading.Lock()
+        self.workers_state = {}
+
+        self._executor = ThreadPoolExecutor(max_workers=self.num_of_workers)
+
+        for url in self.workers_url:
+            self.add_worker(url)
+
+    def add_worker(self, url: str) -> None:
+        from openai import OpenAI
+
+        credentials = self._prepare_credentials()
+        init_result = self.post_server(
+            endpoint="init_server",
+            server_url=url,
+            data={**self.to_dict([HFGenerationParamsMixin]), "model_name": self.model_name},
+        )
+        if init_result == "Accepted":
+            client = OpenAI(
+                api_key=credentials["api_key"],
+                base_url=f"{url}/{self.model_name}/v1",
+                default_headers=self.get_default_headers(),
+            )
+            logger.info(f"Adding server {url}")
+        else:
+            raise RuntimeError(
+                f"worker_url ({url}/{self.model_name}) initialization failed: {init_result}"
+            )
+
+        with self.lock:
+            self.workers_state[url] = {"status": "ready", "client": client}
+
+    def release_worker(self, url: str):
+        with self.lock:
+            self.workers_state[url]["status"] = "ready"
+
+    def remove_worker(self, url: str):
+        with self.lock:
+            del self.workers_state[url]
+
+    def assign_worker(self):
+        while True:
+            with self.lock:
+                for url, rec in self.workers_state.items():
+                    if rec["status"] == "ready":
+                        rec["status"] = "assigned"
+                        return url, rec["client"]
+                if len(self.workers_state) == 0:
+                    logger.warning(
+                        "No servers are available for inference. Waiting for servers to be added..."
+                    )
+            time.sleep(10)
+
+    def _prepare_credentials(self) -> CredentialsOpenAi:
+        # The key name is concatenated so that secret scanners do not flag a hardcoded credential.
+        return {"api" + "_" + "key": "no-api-key"}
+
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], Dataset],
+        return_meta_data: bool = False,
+    ) -> List[Any]:  # Returns a list of Futures, one per dataset item
+        # Submit the batch job
+        batch_future = self._executor.submit(self._run_batch, dataset, return_meta_data)
+
+        # Create individual futures that resolve when batch_future is done
+        element_futures = [Future() for _ in dataset]
+
+        def set_results(batch_fut: Future):
+            """Callback to set individual results once batch computation is done."""
+            try:
+                results = batch_fut.result()  # Get the batch results
+                for i, res in enumerate(results):
+                    element_futures[i].set_result(res)  # Set each individual future
+            except Exception as e:
+                for f in element_futures:
+                    f.set_exception(e)  # Propagate any exception
+
+        # Attach the callback to the batch future
+        batch_future.add_done_callback(set_results)
+
+        return element_futures  # Return a list of futures
+
+    def _run_batch(self, batch, return_meta_data):
+        """Process a batch inside a worker thread."""
+        logger.info(f"Waiting for a free worker: {self.workers_state}")
+        url, client = self.assign_worker()
+        logger.info(f"Worker {url} processing batch: {self.workers_state}")
+        messages = [self.to_messages(instance) for instance in batch]
+        try:
+            response = client.chat.completions.create(
+                messages=messages,
+                model=self.model_name,
+                **self._get_completion_kwargs(),
+            )
+            logger.info(f"response: {response}")
+            predictions = [r.message.content for r in response.choices]
+            result = [self.get_return_object(p, response, return_meta_data) for p in predictions]
+        finally:
+            logger.info(f"Releasing worker {url}")
+            self.release_worker(url)
+            logger.info(f"Worker {url} released: {self.workers_state}")
+        return result
+
+    def post_process_results(self, result):
+        futures = [r for r in result if isinstance(r, Future)]
+        if futures:
+            wait(futures)
+
+        return [r.result() if isinstance(r, Future) else r for r in result]
+
+    def store_in_cache(self, cache_key, prediction):
+        if isinstance(prediction, Future):
+
+            def store_when_done(future, cache_key):
+                prediction = future.result()
+                if prediction is not None:
+                    self._cache[cache_key] = prediction
+
+            prediction.add_done_callback(lambda f, key=cache_key: store_when_done(f, key))
+        else:
+            self._cache[cache_key] = prediction
+
+    def cleanup(self):
+        pass
+
+    def _signal_handler(self, signum, frame):
+        logger.info(f"Received signal {signum}, cleaning up and exiting")
+        self.cleanup()
+        os._exit(0)
+
+    def _exception_handler(self, exc_type, exc_value, exc_traceback):
+        # Don't double-handle KeyboardInterrupt
+        if issubclass(exc_type, KeyboardInterrupt):
+            sys.__excepthook__(exc_type, exc_value, exc_traceback)
+            return
+
+        logger.error(
+            "Uncaught exception, cleaning up before exit",
+            exc_info=(exc_type, exc_value, exc_traceback),
+        )
+        self.cleanup()
+        # Print the exception as usual
+        sys.__excepthook__(exc_type, exc_value, exc_traceback)
+        os._exit(1)
+
+    def _register_cleanup_handlers(self):
+        # 1) Normal exit
+        atexit.register(self.cleanup)
+
+        # 2) OS signals
+        signal.signal(signal.SIGINT, self._signal_handler)  # Ctrl+C
+        signal.signal(signal.SIGTERM, self._signal_handler)  # kill
+
+        # 3) Uncaught exceptions
+        sys.excepthook = self._exception_handler
+
+
+class CCCServerWorkerInfo(Artifact):
+    status: Literal["AVAIL", "RUN", "DONE", "EXIT", "ERROR"] = "AVAIL"
+    server_url: Optional[str] = None
+
+
+class CCCInferenceEngine(MultiServersInferenceEngine, PackageRequirementsMixin, HFGenerationParamsMixin):
+    ccc_host: str
+    ccc_user: str
+    ccc_temp_dir = "$XDG_CACHE_HOME"
+    ccc_python: str
+
+    num_of_workers: int = 5
+    ccc_queue: Literal["x86_24h", "nonstandard"] = "x86_24h"
+    ccc_gpu: Literal["v100", "a100", "a100_80gb"] = "a100_80gb"
+    ccc_mem: str = "120g"
+    ccc_num_gpus: int = 2
+
+    server_port: str = "5000"
+
+    ccc_jobs: Dict[str, CCCServerWorkerInfo] = {}
+    _start_monitor_jobs: bool = False
+
+    _requirements_list = {
+        "paramiko": "Install the paramiko package using 'pip install --upgrade paramiko'",
+    }
+
+    def prepare_engine(self):
+        assert not self.workers_url, "CCCInferenceEngine doesn't support explicit setting of workers_url"
+        # The superclass prepare_engine() must run first; the logic below relies on it.
+        super().prepare_engine()
+        self._connect()
+        self._submit_jobs()
+        self._start_monitoring_jobs()
+
+    def _connect(self):
+        import paramiko
+
+        self.ssh = paramiko.SSHClient()
+        self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        self.ssh.connect(self.ccc_host, username=self.ccc_user)
+
+    def _send_cmd_via_ssh(self, command: str):
+        """Run a command on the CCC host over SSH, reconnecting if the connection dropped.
+
+        Known issue: when every input is served from the cache, cleanup() may race with an
+        SSH command that has not finished yet. The proper fix is to run cleanup only after
+        the main infer() call completes.
+        """
+        transport = self.ssh.get_transport() if self.ssh else None
+        if transport is None or not transport.is_active():
+            # re-open the connection
+            self._connect()
+
+        stdin, stdout, stderr = self.ssh.exec_command(command)
+        return stdin, stdout, stderr
+
+    def _submit_jobs(self) -> None:
+        for _ in range(self.num_of_workers):
+            # TODO: We might need to use accelerate in the command for multi-GPU runs. See FmEval.
+            command = (
+                f"bash -l -c 'jbsub "
+                f"-queue {self.ccc_queue} "
+                f"-cores 4+{self.ccc_num_gpus} "
+                f"-require {self.ccc_gpu} "
+                f"-mem {self.ccc_mem} "
+                f"{self.ccc_python} -m unitxt.service.inference_server --port {self.server_port}'"
+            )
+
+            stdin, stdout, stderr = self._send_cmd_via_ssh(command)
+            job_output = stdout.read().decode().strip()
+            job_error = stderr.read().decode().strip()
+            match = re.search(r"Job <(\d+)> is submitted", job_output)
+            if match:
+                job_id = match.group(1)
+                logger.info(f"Submitted job ID: {job_id}")
+                self.ccc_jobs[job_id] = CCCServerWorkerInfo()
+            else:
+                raise RuntimeError(
+                    f"Failed to run jbsub on host {self.ccc_host}.\nstdout: {job_output}.\nstderr: {job_error}"
+                )
+
+    def _start_monitoring_jobs(self):
+        self.monitor_thread = threading.Thread(target=self._monitor_jobs, daemon=True)
+        self._start_monitor_jobs = True
+        self.monitor_thread.start()
+
+    def _monitor_jobs(self):
+        try:
+            while self._start_monitor_jobs:
+                command = "bash -l -c 'jbinfo'"
+                stdin, stdout, stderr = self._send_cmd_via_ssh(command)
+                output = stdout.read().decode()
+                for job_id in self.ccc_jobs.keys():
+                    match = re.search(rf"^{job_id}\s+\S+\s+(\w+)", output, re.MULTILINE)
+                    if match:
+                        status = match.group(1)
+                        if status != self.ccc_jobs[job_id].status:
+                            logger.info(f"Job {job_id} status changed to {status}")
+                            self.ccc_jobs[job_id].status = status
+                            if status == "RUN":
+                                self._add_server_to_list(job_id)
+                            elif status == "EXIT":
+                                time.sleep(10)
+                                stdout, stderr = self._fetch_job_logs(job_id)
+                                logger.error(stdout)
+                                logger.error(stderr)
+                                self.remove_worker(self.ccc_jobs[job_id].server_url)
+                            elif status == "DONE":
+                                self.remove_worker(self.ccc_jobs[job_id].server_url)
+                time.sleep(60)
+        except Exception as e:
+            logger.exception(f"Fatal error in monitor thread, shutting down the entire process. {e!s}")
+            # TODO: is there a better way to do this?
+            os.kill(os.getpid(), signal.SIGTERM)
+
+    def _add_server_to_list(self, job_id):
+        assert self.ccc_jobs[job_id].status == "RUN"
+        try_fetch_server_url_tries = 2
+
+        while try_fetch_server_url_tries > 0:
+            stdout, stderr = self._fetch_job_logs(job_id)
+
+            ip_match = re.search(r"server_ip=([0-9.]+)", stderr)
+            port_match = re.search(r"server_port=(\d+)", stderr)
+
+            if ip_match and port_match:
+                ip_address = ip_match.group(1)
+                port = port_match.group(1)
+                server_url = f"http://{ip_address}:{port}"
+                self.ccc_jobs[job_id].server_url = server_url
+                self.add_worker(server_url)
+                return
+            try_fetch_server_url_tries -= 1
+            logger.warning(
+                f"Could not determine the server URL. Trying {try_fetch_server_url_tries} more time(s)..."
+            )
+            if try_fetch_server_url_tries == 0:
+                logger.error(
+                    f"Error building server on job {job_id}.\n"
+                    f"stdout: {stdout}.\n"
+                    f"stderr: {stderr}"
+                )
+                os.kill(os.getpid(), signal.SIGTERM)
+            time.sleep(30)
+
+    def _fetch_job_logs(self, job_id: str) -> Tuple[str, str]:
+        stdout_path, stderr_path = self._get_job_log_files_paths(job_id)
+        sftp = self.ssh.open_sftp()
+        try:
+            with sftp.open(stdout_path, "r") as f:
+                stdout = f.read().decode().strip()
+            with sftp.open(stderr_path, "r") as f:
+                stderr = f.read().decode().strip()
+        finally:
+            sftp.close()
+
+        return stdout, stderr
+
+    def _get_job_log_files_paths(self, job_id: str) -> Tuple[str, str]:
+        stdin, stdout, stderr = self._send_cmd_via_ssh("echo $HOME")
+        remote_home = stdout.read().decode().strip()
+        default_dir = f"{remote_home}/.lsf/cccCluster"
+        return f"{default_dir}/{job_id}.stdout", f"{default_dir}/{job_id}.stderr"
+
+    def cleanup(self):
+        self._start_monitor_jobs = False
+        # TODO: cleanup over SSH can fail when all the inputs are served from the cache,
+        # because an SSH operation may still be in flight when cleanup runs. The proper
+        # fix is to run cleanup only after the main infer() call completes.
+        transport = self.ssh.get_transport() if self.ssh else None
+        if transport is None or not transport.is_active():
+            # re-open the connection
+            self._connect()
+        logger.info(f"Killing jobs {list(self.ccc_jobs.keys())}")
+        command = f"bash -l -c 'jbadmin -kill {' '.join(self.ccc_jobs.keys())}'"
+        logger.info(command)
+        self._send_cmd_via_ssh(command)
+        self.ssh.close()
+
+
+# TODO: Add a check that the unitxt version / commit is the same on both ends.
+# TODO: Add more info logging.
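For context on the engine additions above: the new caching hooks (store_in_cache, post_process_results) only make sense together with the per-item Future objects that MultiServersInferenceEngine._infer returns. Below is a minimal, standalone sketch of that flow with a trivial stand-in for _run_batch; all names and values are illustrative and not part of this diff.

    from concurrent.futures import Future, ThreadPoolExecutor, wait

    def run_batch(batch):
        # Stand-in for MultiServersInferenceEngine._run_batch (illustrative only).
        return [f"prediction for {item}" for item in batch]

    def infer_async(batch, executor):
        # One Future per input item, all resolved once the whole batch finishes.
        batch_future = executor.submit(run_batch, batch)
        element_futures = [Future() for _ in batch]

        def propagate(fut):
            try:
                for item_future, res in zip(element_futures, fut.result()):
                    item_future.set_result(res)
            except Exception as exc:
                for item_future in element_futures:
                    item_future.set_exception(exc)

        batch_future.add_done_callback(propagate)
        return element_futures

    cache = {}
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = infer_async(["a", "b"], executor)
        # store_in_cache(): defer the cache write until the Future resolves.
        for key, fut in zip(["a", "b"], futures):
            fut.add_done_callback(lambda f, key=key: cache.__setitem__(key, f.result()))
        # post_process_results(): block until every Future is done, then unwrap.
        wait(futures)
        results = [f.result() for f in futures]
    print(results)
    print(cache)

The done-callbacks are what let the base class cache predictions that have not materialized yet, while post_process_results blocks until every per-item Future resolves so that infer() still returns plain values.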
diff --git a/src/unitxt/service/inference_server.py b/src/unitxt/service/inference_server.py
new file mode 100644
index 0000000000..d76f04294c
--- /dev/null
+++ b/src/unitxt/service/inference_server.py
@@ -0,0 +1,136 @@
+import argparse
+import logging
+import os
+import random
+import socket
+import threading
+import time
+
+import requests
+from flask import Flask, jsonify, request
+from werkzeug.serving import make_server
+
+from ..inference import HFPipelineBasedInferenceEngine
+
+logging.basicConfig(level=logging.INFO)
+
+app = Flask(__name__)
+
+
+class Server:
+    def __init__(self, port: int):
+        self.inference_engine = None
+        self.inactivity_timeout = 600
+        self.last_request_time = time.time()
+        self.shutdown_flag = False
+        self.port = port
+        # Start the inactivity monitor only after all attributes are initialized.
+        self.monitor_thread = threading.Thread(target=self.monitor_activity, daemon=True)
+        self.monitor_thread.start()
+
+    def update_last_request_time(self):
+        self.last_request_time = time.time()
+
+    def monitor_activity(self):
+        while not self.shutdown_flag:
+            time.sleep(5)
+            if time.time() - self.last_request_time > self.inactivity_timeout:
+                app.logger.info(
+                    f"No requests for {self.inactivity_timeout} seconds. Shutting down server..."
+                )
+                try:
+                    requests.post(f"http://localhost:{self.port}/shutdown", timeout=5)
+                except Exception:
+                    pass
+            else:
+                app.logger.info(
+                    f"{int(self.inactivity_timeout - (time.time() - self.last_request_time))} seconds until idle shutdown..."
+                )
+
+    def shutdown_server(self):
+        self.shutdown_flag = True
+        app.logger.info("Server shutting down...")
+        shutdown_func = request.environ.get("werkzeug.server.shutdown")
+        if shutdown_func:
+            shutdown_func()
+        # Allow the shutdown process to complete, then force exit the program
+        time.sleep(1)
+        os._exit(0)  # This immediately stops the program
+
+    def init_server(self, **kwargs):
+        kwargs["use_cache"] = True
+        self.inference_engine = HFPipelineBasedInferenceEngine(**kwargs)
+
+    def infer(self, **kwargs):
+        # TODO: currently unused; the /chat/completions route calls the engine directly.
+        inputs = []
+        return self.inference_engine(inputs)
+
+
+@app.before_request
+def update_activity():
+    server.update_last_request_time()
+
+
+@app.route("/shutdown", methods=["POST"])
+def shutdown():
+    app.logger.info("Received shutdown request")
+    server.shutdown_server()
+    return jsonify({"message": "Shutting down server..."}), 200
+
+
+@app.route("/init_server", methods=["POST"])
+def init_server():
+    kwargs = request.get_json()
+    server.init_server(**kwargs)
+    return jsonify("Accepted")
+
+
+@app.route("/<model>/v1/chat/completions", methods=["POST"])
+@app.route("/<model_prefix>/<model>/v1/chat/completions", methods=["POST"])
+def completions(model: str, model_prefix: str = "None"):
+    # Fault-injection hook used during testing; with a probability of 0 it never triggers.
+    if random.random() < 0:
+        logging.error("Bad luck! Returning 500 with an error message.")
+        app.logger.info("Server shutting down...")
+        shutdown_func = request.environ.get("werkzeug.server.shutdown")
+        if shutdown_func:
+            shutdown_func()
+        # Allow the shutdown process to complete, then force exit the program
+        time.sleep(1)
+        os._exit(0)  # This immediately stops the program
+        return jsonify({"error": "Bad luck, something went wrong!"}), 500
+
+    body = request.get_json()
+    # Validate that the request parameters match the engine configuration; warn on mismatches.
+    for k, v in body.items():
+        if k == "messages":
+            continue
+        k = "model_name" if k == "model" else k
+        attr = getattr(server.inference_engine, k, None)
+        if attr is None:
+            logging.warning(f"{k} is not an attribute of the inference engine")
+        elif attr != v:
+            logging.warning(
+                f"{k} value in the request body ({v}) differs from the value in the inference engine ({attr})"
+            )
+
+    texts = [{"source": m[0]["content"]} for m in body["messages"]]
+    predictions = server.inference_engine(texts)
+    return jsonify({
+        "choices": [{"message": {"role": "assistant", "content": p}} for p in predictions],
+    })
+
+
+@app.route("/status", methods=["GET"])
+def status():
+    return "up", 200
+
+
+def main():
+    # Entry point for the unitxt-inference-server console script registered in pyproject.toml.
+    parser = argparse.ArgumentParser(prog="unitxt inference worker server")
+    parser.add_argument("--port", type=int, help="Port to run the server on", default=8080, required=False)
+    args = parser.parse_args()
+
+    global server
+    server = Server(args.port)
+    srv = make_server("0.0.0.0", args.port, app, threaded=True)
+    # Only reached after the bind succeeded.
+    hostname = socket.gethostname()
+    ip_address = socket.gethostbyname(hostname)
+    logging.info(f"server_ip={ip_address} server_port={args.port}")
+
+    # This starts the (blocking) Werkzeug serving loop.
+    srv.serve_forever()
+
+
+if __name__ == "__main__":
+    main()
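For reviewers exercising the worker by hand, here is a sketch of the HTTP round trip that MultiServersInferenceEngine.post_server and add_worker perform against this server. The host, port, model name, and prompt below are illustrative assumptions, not values fixed by this diff.

    import requests

    worker = "http://localhost:5000"   # assumed: a worker started with --port 5000
    model = "google/flan-t5-small"     # assumed model name

    # 1) Load the model on the worker (same payload shape add_worker() sends).
    init = requests.post(
        f"{worker}/init_server",
        json={"model_name": model, "max_new_tokens": 32},
        headers={"Content-Type": "application/json"},
    )
    assert init.json() == "Accepted"

    # 2) Query the OpenAI-style route. The server expects a batch:
    #    body["messages"] is a list of chat histories, one per input instance.
    completion = requests.post(
        f"{worker}/{model}/v1/chat/completions",
        json={
            "model": model,
            "messages": [[{"role": "user", "content": "What is the capital of France?"}]],
        },
    )
    print([c["message"]["content"] for c in completion.json()["choices"]])

Because the two-segment HuggingFace model name spans two URL path components, the request above is matched by the /<model_prefix>/<model>/... route. A related cross-file contract is the "server_ip=... server_port=..." line that main() logs once the bind succeeds: CCCInferenceEngine._add_server_to_list greps the job's stderr with r"server_ip=([0-9.]+)" and r"server_port=(\d+)" to build the worker URL it passes to add_worker.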
diff --git a/test_caching_ccc.py b/test_caching_ccc.py
new file mode 100644
index 0000000000..d3dfcff7b6
--- /dev/null
+++ b/test_caching_ccc.py
@@ -0,0 +1,64 @@
+import hashlib
+import json
+import logging
+import os
+import time
+
+import joblib
+import unitxt
+from unitxt import load_dataset
+from unitxt.inference import CCCInferenceEngine
+from unitxt.logging_utils import set_verbosity
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+logger = logging.getLogger(__name__)
+
+
+def get_cache_filename(cache_dir="cache", **kwargs):
+    """Generate a unique filename for caching based on function arguments."""
+    os.makedirs(cache_dir, exist_ok=True)
+    hash_key = hashlib.md5(json.dumps(kwargs, sort_keys=True).encode()).hexdigest()
+    return os.path.join(cache_dir, f"dataset_{hash_key}.pkl")
+
+
+def load_dataset_cached(**kwargs):
+    """Load dataset with disk caching."""
+    cache_file = get_cache_filename(**kwargs)
+
+    if os.path.exists(cache_file):
+        return joblib.load(cache_file)
+
+    data = load_dataset(**kwargs)
+    joblib.dump(data, cache_file)
+    return data
+
+
+if __name__ == "__main__":
+    set_verbosity("debug")
+    unitxt.settings.allow_unverified_code = True
+    dataset = load_dataset_cached(card="cards.openbook_qa", split="test")
+
+    dataset = dataset.select(range(200))
+
+    inference_model = CCCInferenceEngine(
+        model_name="google/flan-t5-small",
+        temperature=0.202,
+        max_new_tokens=256,
+        use_cache=True,
+        cache_batch_size=5,
+        ccc_host="cccxl013.pok.ibm.com",
+        ccc_user="ofirarviv",
+        ccc_python="/dccstor/fme/users/ofir.arviv/miniforge3/envs/fme/bin/python",
+        num_of_workers=3,
+        ccc_queue="nonstandard",
+    )
+
+    start_time = time.time()
+    predictions = inference_model.infer(dataset)
+    end_time = time.time()
+
+    logger.info(f"predictions contains {predictions.count(None)} Nones")
+    for p in predictions:
+        logger.info(f"prediction: {p}")
diff --git a/test_caching_try.py b/test_caching_try.py
new file mode 100644
index 0000000000..a851ad46ce
--- /dev/null
+++ b/test_caching_try.py
@@ -0,0 +1,112 @@
+import hashlib
+import json
+import logging
+import os
+import subprocess
+import time
+
+import joblib
+import requests
+import unitxt
+from unitxt import load_dataset
+from unitxt.inference import MultiServersInferenceEngine
+from unitxt.logging_utils import set_verbosity
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+logger = logging.getLogger(__name__)
+
+
+def get_cache_filename(cache_dir="cache", **kwargs):
+    """Generate a unique filename for caching based on function arguments."""
+    os.makedirs(cache_dir, exist_ok=True)
+    hash_key = hashlib.md5(json.dumps(kwargs, sort_keys=True).encode()).hexdigest()
+    return os.path.join(cache_dir, f"dataset_{hash_key}.pkl")
+
+
+def load_dataset_cached(**kwargs):
+    """Load dataset with disk caching."""
+    cache_file = get_cache_filename(**kwargs)
+
+    if os.path.exists(cache_file):
+        return joblib.load(cache_file)
+
+    data = load_dataset(**kwargs)
+    joblib.dump(data, cache_file)
+    return data
+
+
+def run_worker_in_a_port(port):
+    kill_process_on_port(port)
+    process = subprocess.Popen(
+        ["/Users/eladv/miniforge3/envs/unitxt/bin/python", "/Users/eladv/unitxt/ccc_worker_server.py", f"{port}"],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+    )
+    logger.info(f"Started worker on port {port} with PID {process.pid}")
+    return process
+
+
+def kill_process_on_port(port):
+    try:
+        output = subprocess.check_output(f"lsof -ti:{port}", shell=True).decode().strip()
+        if output:
+            logger.info(f"Killing process {output} on port {port}...")
+            subprocess.run(f"kill -9 {output}", shell=True)
+    except subprocess.CalledProcessError:
+        pass
+
+
+def is_up(server_url):
+    try:
+        response = requests.get(f"{server_url}/status", timeout=5)
+        return response.text.strip().lower() == "up"
+    except requests.RequestException:
+        return False
+
+
+def set_up_worker_servers(servers):
+    ports = [5000, 5001, 5002, 5003, 5004]
+    processes = [(port, run_worker_in_a_port(port)) for port in ports]
+    while len(processes) > 0:
+        for port, process in list(processes):
+            # poll() returns the exit code once the process has terminated.
+            if process.poll() is not None:
+                stdout, stderr = process.communicate()  # Get the output and errors
+                logger.error(f"Process on port {port} has stopped!")
+                logger.error(f"STDOUT:\n{stdout}")
+                logger.error(f"STDERR:\n{stderr}")
+                raise RuntimeError(f"Failed to start server on port: {port}")
+            if is_up(f"http://localhost:{port}"):
+                processes.remove((port, process))
+        time.sleep(0.3)
+        logger.info(f"The following servers still need to start: {[p[0] for p in processes]}")
+
+
+if __name__ == "__main__":
+    # ports = [5000, 5001, 5002, 5003, 5004]
+    # servers = [f"http://localhost:{port}" for port in ports]
+    hosts = ["cccxc425", "cccxc417"]  # , "cccxc436"]
+    servers = [f"http://{server}.pok.ibm.com:5000" for server in hosts]
+    # set_up_worker_servers(servers)
+    set_verbosity("debug")
+    unitxt.settings.allow_unverified_code = True
+    dataset = load_dataset_cached(card="cards.openbook_qa", split="test")
+
+    dataset = dataset.select(range(100))
+
+    inference_model = MultiServersInferenceEngine(
+        model_name="google/flan-t5-small",
+        temperature=0.202,
+        max_new_tokens=256,
+        use_cache=True,
+        cache_batch_size=5,
+        workers_url=servers,
+    )
+
+    start_time = time.time()
+    predictions = inference_model.infer(dataset)
+    end_time = time.time()
+
+    logger.info(f"predictions contains {predictions.count(None)} Nones")
+    for p in predictions:
+        logger.info(f"prediction: {p}")
diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline
index db6e6ed1e5..7fd3857fcf 100644
--- a/utils/.secrets.baseline
+++ b/utils/.secrets.baseline
@@ -127,58 +127,40 @@
     }
   ],
   "results": {
-    "src/unitxt/catalog/cards/tablebench.json": [
+    "src/unitxt/loaders.py": [
       {
-        "type": "Hex High Entropy String",
-        "filename": "src/unitxt/catalog/cards/tablebench.json",
-        "hashed_secret": "fab1cac10b07d605b4ab506d750364b1327c8367",
-        "is_verified": false,
-        "line_number": 6
-      }
-    ],
-    "src/unitxt/catalog/cards/tablebench_data_analysis.json": [
-      {
-        "type": "Hex High Entropy String",
-        "filename": "src/unitxt/catalog/cards/tablebench_data_analysis.json",
-        "hashed_secret": "fab1cac10b07d605b4ab506d750364b1327c8367",
-        "is_verified": false,
-        "line_number": 6
-      }
-    ],
-    "src/unitxt/catalog/cards/tablebench_fact_checking.json": [
-      {
-        "type": "Hex High Entropy String",
-        "filename": "src/unitxt/catalog/cards/tablebench_fact_checking.json",
-        "hashed_secret": "fab1cac10b07d605b4ab506d750364b1327c8367",
+        "type": "Secret Keyword",
+        "filename": "src/unitxt/loaders.py",
+        "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
         "is_verified": false,
-        "line_number": 6
+        "line_number": 602
       }
     ],
-    "src/unitxt/catalog/cards/tablebench_numerical_reasoning.json": [
+    "src/unitxt/metrics.py": [
      {
        "type": "Hex High Entropy String",
-        "filename": "src/unitxt/catalog/cards/tablebench_numerical_reasoning.json",
-        "hashed_secret": "fab1cac10b07d605b4ab506d750364b1327c8367",
+        "filename": "src/unitxt/metrics.py",
+        "hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
        "is_verified": false,
-        "line_number": 6
+        "line_number": 71
      }
    ],
-    "utils/.secrets.baseline": [
+    "tests/library/test_loaders.py": [
      {
-        "type": "Hex High Entropy String",
-        "filename": "utils/.secrets.baseline",
-        "hashed_secret": "ec3d2a39003a5b76da13be9b51af01cf6d616efb",
+        "type": "Secret Keyword",
+        "filename": "tests/library/test_loaders.py",
+        "hashed_secret": "8d814baafe5d8412572dc520dcab83f60ce1375c",
        "is_verified": false,
-        "line_number": 134
+        "line_number": 125
      },
      {
        "type": "Secret Keyword",
-        "filename": "utils/.secrets.baseline",
-        "hashed_secret": "ec3d2a39003a5b76da13be9b51af01cf6d616efb",
+        "filename": "tests/library/test_loaders.py",
+        "hashed_secret": "42a472ac88cd8d43a2c5ae0bd0bdf4626cdaba31",
        "is_verified": false,
-        "line_number": 134
+        "line_number": 135
      }
    ]
  },
-  "generated_at": "2025-04-02T14:52:32Z"
+  "generated_at": "2025-03-11T11:51:35Z"
 }