import os
import psutil
import torch
import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces
import matplotlib.pyplot as plt
import MetaTrader5 as mt5
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
import threading
import time
# ==================== [НАСТРОЙКИ ДЛЯ РУЧНОГО ТЕСТА] ====================
SYMBOL = "BTCUSD"
CANDLE_COUNT = 30000 # Размер истории
WINDOW_SIZE = 40 # Сколько свечей видит агент
# Параметры PPO (крути их здесь)
LEARNING_RATE = 0.0002 # Скорость обучения (чуть снизили для стабильности)
ENTROPY_COEF = 0.02 # "Любопытство" (повысили, чтобы он чаще пробовал входить)
N_STEPS = 2048 # Шагов до обновления
BATCH_SIZE = 128 # Размер пакета данных
TOTAL_TIMESTEPS = 1500000 # Увеличили до 1.5 млн шагов
# Лимиты ядер
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
# =======================================================================
def prepare_data(symbol=SYMBOL, length=CANDLE_COUNT):
if not mt5.initialize():
print("Ошибка инициализации MT5")
return None
rates = mt5.copy_rates_from_pos(symbol, mt5.TIMEFRAME_M1, 1, length)
mt5.shutdown()
if rates is None:
print("Данные не получены")
return None
df = pd.DataFrame(rates)
df.columns = [c.lower() for c in df.columns]
# 1. RSI
delta = df['close'].diff()
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
df['rsi'] = 100 - (100 / (1 + (gain / (loss + 1e-9))))
# 2. ATR (Нормализованный)
tr = pd.concat([df['high'] - df['low'], (df['high'] - df['close'].shift()).abs()], axis=1).max(axis=1)
df['atr'] = tr.rolling(window=14).mean() / df['close']
# 3. Volume (Нормализованный)
df['vol_norm'] = df['tick_volume'] / (df['tick_volume'].rolling(window=100).mean() + 1e-9)
# 4. Close (Нормализованный)
df['close_norm'] = df['close'] / df['close'].rolling(window=100).mean()
df.dropna(inplace=True)
# Возвращаем нужные фичи + оригинальную цену для расчетов
return df[['close_norm', 'rsi', 'atr', 'vol_norm', 'close']]
class BareTradingEnv(gym.Env):
def __init__(self, df):
super().__init__()
self.df = df.reset_index(drop=True)
self.action_space = spaces.Discrete(3) # 0=Wait, 1=Buy, 2=Sell
# Видит 4 индикатора
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(WINDOW_SIZE, 4), dtype=np.float32)
self.commission = 2.0 # Спред
self.initial_balance = 0.0 # Считаем просто пункты
self.reset()
def reset(self, seed=None, options=None):
super().reset(seed=seed)
self.current_step = WINDOW_SIZE
self.balance = 0.0
self.is_in_position = False
self.entry_price = 0.0
return self._get_obs(), {}
def _get_obs(self):
# Берем только первые 4 колонки (индикаторы)
obs = self.df.iloc[self.current_step - WINDOW_SIZE : self.current_step, :4].values
return obs.astype(np.float32)
def step(self, action):
current_price = self.df.loc[self.current_step, 'close']
reward = 0.0
# 1. Логика сделок
if action == 1 and not self.is_in_position: # КУПИТЬ
self.is_in_position = True
self.entry_price = current_price + self.commission
reward = -self.commission # Сразу штраф на спред
elif action == 2 and self.is_in_position: # ПРОДАТЬ (Закрыть лонг)
final_profit = current_price - self.entry_price
self.balance += final_profit
reward = 0 # Вся награда уже была получена по ходу движения
self.is_in_position = False
self.entry_price = 0
self.current_step += 1
terminated = self.current_step >= len(self.df) - 1
# 2. "Живая" награда (Dense Reward)
if self.is_in_position and not terminated:
next_price = self.df.loc[self.current_step, 'close']
step_movement = next_price - current_price # Изменение цены за 1 минуту
reward += step_movement
# 3. Штраф за лень (бездействие)
if not self.is_in_position:
reward -= 0.01 # Очень маленький минус, чтобы мотивировать искать вход
obs = self._get_obs() if not terminated else np.zeros((WINDOW_SIZE, 4))
return obs, reward, terminated, False, {"balance": self.balance}
def open_browser():
"""Функция для Selenium"""
time.sleep(7) # Ждем пока TensorBoard запустится
try:
options = webdriver.ChromeOptions()
options.add_experimental_option("detach", True)
driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=options)
driver.get("http://localhost:6006")
except Exception as e:
print(f"Не удалось открыть браузер: {e}")
if __name__ == '__main__':
# Снимаем ограничения ядер Windows
if os.name == 'nt':
psutil.Process().cpu_affinity(list(range(48)))
print("--- ЗАГРУЗКА ДАННЫХ ---")
data = prepare_data()
if data is None:
exit()
split = int(len(data) * 0.8)
train_df = data.iloc[:split]
test_df = data.iloc[split:]
# Запуск 48 потоков (на твоем Dual Xeon)
n_envs = os.cpu_count()
print(f"Использую {n_envs} ядер...")
train_env = SubprocVecEnv([lambda: BareTradingEnv(train_df) for _ in range(48)])
# Запуск TensorBoard в фоне
os.system("start /B tensorboard --logdir ./ppo_dense_logs/")
# Запуск Selenium в отдельном потоке
threading.Thread(target=open_browser, daemon=True).start()
# Создание модели
model = PPO(
"MlpPolicy",
train_env,
learning_rate=LEARNING_RATE,
ent_coef=ENTROPY_COEF,
n_steps=N_STEPS,
batch_size=BATCH_SIZE,
tensorboard_log="./ppo_dense_logs/",
verbose=1
)
print(f"--- ОБУЧЕНИЕ ЗАПУЩЕНО ({TOTAL_TIMESTEPS} шагов) ---")
model.learn(total_timesteps=TOTAL_TIMESTEPS, progress_bar=True)
model.save("ppo_dense_trading_model")
# --- ФИНАЛЬНЫЙ ТЕСТ ---
test_env = BareTradingEnv(test_df)
obs, _ = test_env.reset()
prices, buys, sells = [], [], []
print("Запуск финального теста...")
for i in range(len(test_df) - WINDOW_SIZE - 1):
action, _ = model.predict(obs, deterministic=True)
price = test_env.df.loc[test_env.current_step, 'close']
was_in = test_env.is_in_position
obs, _, done, _, _ = test_env.step(action)
if not was_in and test_env.is_in_position: buys.append((i, price))
if was_in and not test_env.is_in_position: sells.append((i, price))
prices.append(price)
if done: break
# Визуализация
plt.style.use('dark_background')
plt.figure(figsize=(15, 7))
plt.plot(prices, alpha=0.3, color='gray')
if buys:
b_idx, b_pr = zip(*buys)
plt.scatter(b_idx, b_pr, marker='^', color='lime', s=100, label='BUY')
if sells:
s_idx, s_pr = zip(*sells)
plt.scatter(s_idx, s_pr, marker='v', color='red', s=100, label='SELL')
plt.title(f"Result. Profit: {test_env.balance:.2f} points")
plt.legend()
plt.show() Click Run or press shift + ENTER to run code