From 209f208ff6f7ec27ec5fd830ba710eaea087880b Mon Sep 17 00:00:00 2001 From: Linlang Date: Fri, 26 Sep 2025 12:14:49 +0800 Subject: [PATCH] fix: gen_training_orders.py bugs --- examples/rl_order_execution/scripts/gen_training_orders.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/rl_order_execution/scripts/gen_training_orders.py b/examples/rl_order_execution/scripts/gen_training_orders.py index b03ce6e5a8..88e65463f9 100755 --- a/examples/rl_order_execution/scripts/gen_training_orders.py +++ b/examples/rl_order_execution/scripts/gen_training_orders.py @@ -17,11 +17,11 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> bool: if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5: return False - df["date"] = df["datetime"].dt.date.astype("datetime64") + df["date"] = df["datetime"].dt.date.astype("datetime64[ns]") df = df.set_index(["instrument", "datetime", "date"]) - df = df.groupby("date", group_keys=False).take(range(start_idx, end_idx)).droplevel(level=0) + df = df.groupby("date", group_keys=True).take(range(start_idx, end_idx)).droplevel(level=0) - order_all = pd.DataFrame(df.groupby(level=(2, 0), group_keys=False).mean().dropna()) + order_all = pd.DataFrame(df.groupby(level=(2, 0), group_keys=True).mean().dropna()) order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"] order_all = order_all[order_all["amount"] > 0.0] order_all["order_type"] = 0