diff --git a/examples/rl_order_execution/scripts/gen_training_orders.py b/examples/rl_order_execution/scripts/gen_training_orders.py index b03ce6e5a8..88e65463f9 100755 --- a/examples/rl_order_execution/scripts/gen_training_orders.py +++ b/examples/rl_order_execution/scripts/gen_training_orders.py @@ -17,11 +17,11 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> bool: if len(df) == 0 or df.isnull().values.any() or min(df["$volume0"]) < 1e-5: return False - df["date"] = df["datetime"].dt.date.astype("datetime64") + df["date"] = df["datetime"].dt.date.astype("datetime64[ns]") df = df.set_index(["instrument", "datetime", "date"]) - df = df.groupby("date", group_keys=False).take(range(start_idx, end_idx)).droplevel(level=0) + df = df.groupby("date", group_keys=True).take(range(start_idx, end_idx)).droplevel(level=0) - order_all = pd.DataFrame(df.groupby(level=(2, 0), group_keys=False).mean().dropna()) + order_all = pd.DataFrame(df.groupby(level=(2, 0), group_keys=True).mean().dropna()) order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"] order_all = order_all[order_all["amount"] > 0.0] order_all["order_type"] = 0