Automated Trading with a Reinforcement Learning PPO Model, Part 2
import warnings

import numpy as np
import pandas as pd
import gym  # classic gym API (4-tuple step), i.e., stable-baselines3 < 2.0
from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from sklearn.model_selection import train_test_split, ParameterGrid
from datetime import datetime
from pytz import timezone

# Suppress specific warning messages
warnings.filterwarnings("ignore", category=UserWarning)
class TradingEnv(gym.Env):
    def __init__(self, df):
        super(TradingEnv, self).__init__()
        self.df = df
        self.current_step = 0
        self.df_numeric = self._convert_to_numeric(self.df)
        self.action_space = spaces.Discrete(2)  # Buy or Hold
        self.observation_space = spaces.Box(low=0, high=1, shape=(self.df_numeric.shape[1],), dtype=np.float32)
        self.initial_balance = 10000
        self.balance = self.initial_balance
        self.net_worth = self.initial_balance
        self.max_net_worth = self.initial_balance
        self.shares_held = 0
        self.current_price = self.df.iloc[self.current_step]['close']  # initialize with the price at the current step
        self.price_history = self.df['close'].tolist()  # keep the full close-price series as a list
    def _convert_to_numeric(self, df):
        # Keep only the feature columns prefixed with "m_" and coerce them to numeric.
        df_numeric = df.copy()
        df_numeric = df_numeric.filter(regex='^m_')
        df_numeric.reset_index(drop=True, inplace=True)  # drop and reset the index
        for column in df_numeric.columns:
            df_numeric[column] = pd.to_numeric(df_numeric[column], errors='coerce')
        df_numeric.fillna(0, inplace=True)
        return df_numeric
    def reset(self):
        self.balance = self.initial_balance
        self.net_worth = self.initial_balance
        self.max_net_worth = self.initial_balance
        self.shares_held = 0
        self.current_step = 0
        self.current_price = self.df.iloc[self.current_step]['close']  # re-initialize the current price on reset
        return self._next_observation()
    def _next_observation(self):
        # Per-row normalization: scale the current row by its own maximum value.
        obs = self.df_numeric.iloc[self.current_step].values
        obs_max = obs.max() if obs.max() != 0 else 1  # prevent division by zero
        obs = obs / obs_max
        return obs
    def step(self, action):
        self.current_step += 1
        self.current_price = self.df.iloc[self.current_step]['close']  # update the current close price each step
        self.low_price = self.df.iloc[self.current_step]['low']        # update the current low price each step
        self.current_time = self.df.index[self.current_step]           # update the current timestamp each step
        if action == 1:  # Buy
            self.shares_held += self.balance / self.current_price
            self.balance = 0
        elif action == 0:  # Hold
            pass
        self.net_worth = self.balance + self.shares_held * self.current_price
        self.max_net_worth = max(self.max_net_worth, self.net_worth)
        # Compute the reward from the price movement over the following hour
        reward = self.calculate_reward(action)
        done = self.current_step >= len(self.df) - 1
        obs = self._next_observation()
        return obs, reward, done, {}
    def calculate_reward(self, action):
        '''
        Starting from the current price, check every price over the next 12 steps.
        If any of them is 5% or more above the current price, return a reward of 1;
        if no 5%+ rise occurs within the hour, return 0.
        In other words, the reward grades whether issuing a Buy signal was a good
        call, and the model is trained on that grade.
        '''
        end_step = min(self.current_step + 12, len(self.df) - 1)  # 1 hour = 12 steps (assuming 5-minute candles)
        reward = 0
        if action == 1:  # compute a reward only for Buy actions
            for step in range(1, end_step - self.current_step + 1):
                future_price = self.price_history[self.current_step + step]
                price_increase = (future_price - self.current_price) / self.current_price
                if step <= 5:
                    if future_price < self.low_price:  # price falls below the current candle's low within 5 bars (30 min): no reward
                        break
                if price_increase >= 0.05:  # 5% or greater rise
                    print("%s self.current_step:%s" % (self.current_time, self.current_step))
                    print("for range step:%s" % (step))
                    print("future_price:%s" % (self.price_history[self.current_step + step]))
                    print("price_increase:%s" % ((future_price - self.current_price) / self.current_price))
                    print("reward = 1")
                    reward = 1
                    break
        return reward  # 1 if a 5%+ rise occurred within the hour, otherwise 0
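Before handing the environment to PPO, it can help to sanity-check it in isolation with a small synthetic DataFrame. The sketch below is illustrative only: the m_rsi and m_macd column names are made-up placeholders (any numeric columns prefixed with "m_" will do), and the environment only needs 'close', 'low', and a DatetimeIndex besides those features.

# Minimal smoke test for TradingEnv on synthetic data (placeholder columns).
n = 100
index = pd.date_range("2024-01-01", periods=n, freq="5min")
close = pd.Series(100 + np.cumsum(np.random.randn(n)), index=index)
df = pd.DataFrame({
    "close": close,
    "low": close - 0.5,          # fake intrabar low
    "m_rsi": np.random.rand(n),  # placeholder feature column
    "m_macd": np.random.rand(n), # placeholder feature column
}, index=index)

env = TradingEnv(df)
obs = env.reset()
obs, reward, done, info = env.step(1)  # force a Buy and inspect the reward
print("reward:", reward, "net_worth:", env.net_worth)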
def optimize_ppo(data, param_grid, model_path="ppo_trading_model"):
    env = TradingEnv(data)
    best_model = None
    best_reward = -float('inf')
    for params in ParameterGrid(param_grid):
        model = PPO('MlpPolicy', env, verbose=1, **params)
        model.learn(total_timesteps=10000)
        total_rewards = evaluate_model(model, data)
        if total_rewards > best_reward:
            best_reward = total_rewards
            best_model = model
    best_model.save(model_path)
    return best_model
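ParameterGrid expands a dictionary of lists into every combination, so the grid passed in from train_model below (3 n_steps values x 3 learning rates x 2 batch sizes) trains 18 models for 10,000 timesteps each. A quick way to see what the loop iterates over:

from sklearn.model_selection import ParameterGrid

grid = {'n_steps': [128, 256], 'learning_rate': [1e-3, 1e-4]}
for params in ParameterGrid(grid):
    print(params)
# Prints all four {'learning_rate': ..., 'n_steps': ...} combinations.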
def train_model(data, model_path="ppo_trading_model"):
    env = TradingEnv(data)
    try:
        model = PPO.load(model_path)
        print("Model loaded successfully. Continuing training...")
    except Exception:
        model = PPO('MlpPolicy', env, verbose=1)
        print("New model initialized.")
    model.set_env(env)
    param_grid = {
        'n_steps': [128, 256, 512],
        'learning_rate': [1e-3, 1e-4, 1e-5],
        'batch_size': [128, 256],  # changed: multiples of 128, so each minibatch evenly divides the rollout buffer
    }
    best_model = optimize_ppo(data, param_grid, model_path)
    best_model.learn(total_timesteps=10000)
    best_model.save(model_path)
    return best_model
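The try/except above hides the reason a load fails. If you prefer an explicit check: SB3 appends a .zip extension when saving, so one alternative sketch (file name assumed to match model_path) looks like this.

import os

# Alternative load pattern: check for the file SB3 actually writes.
# save("ppo_trading_model") produces "ppo_trading_model.zip" on disk.
env = TradingEnv(data)  # assumes `data` is in scope, as in train_model()
if os.path.exists("ppo_trading_model.zip"):
    model = PPO.load("ppo_trading_model", env=env)  # attach the env at load time
else:
    model = PPO('MlpPolicy', env, verbose=1)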
def load_model(model_path="ppo_trading_model"):
    return PPO.load(model_path)
def evaluate_model(model, data):
    env = TradingEnv(data)
    obs = env.reset()
    total_rewards = 0
    done = False
    while not done:
        action, _states = model.predict(obs)
        obs, reward, done, _ = env.step(action)
        total_rewards += reward
    return total_rewards
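Note that evaluate_policy is imported at the top but never used; the custom loop above plays a single pass through the data instead. Since TradingEnv terminates cleanly at the end of the DataFrame, SB3's built-in evaluator could serve as a cross-check, roughly like this (assumes a trained model and the data frame are in scope):

from stable_baselines3.common.evaluation import evaluate_policy

env = TradingEnv(data)
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=1, deterministic=True)
print(f"mean reward per episode: {mean_reward:.2f} +/- {std_reward:.2f}")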
def main():
    ticker = 'XEM'
    chart_intervals = 'minute5'
    current_time = pd.to_datetime(datetime.now(timezone('Asia/Seoul'))).strftime("%Y-%m-%d %H:%M:%S")
    # save_db_market_infos, get_strategy_mst_data, and calculate_indicators
    # come from part 1 of this series and are not shown here.
    chart_data = save_db_market_infos(ticker=ticker, chart_intervals=chart_intervals, current_time=current_time)
    strategy_data = get_strategy_mst_data()
    if chart_data is not None:
        chart_data.set_index('time', inplace=True)  # index by timestamp before the None check would be unsafe
        strategy_chart_data = calculate_indicators(chart_data, ticker)
        strategy_chart_data_df = pd.DataFrame([strategy_chart_data])
        train_data, test_data = train_test_split(chart_data, test_size=0.2, shuffle=False)
        model = train_model(train_data)
        total_rewards = evaluate_model(model, test_data)
        print(f"Total Rewards: {total_rewards}")
        if isinstance(strategy_chart_data_df, pd.DataFrame):
            obs = strategy_chart_data_df.values.flatten().astype(np.float32)
            obs = np.expand_dims(obs, axis=0)
            action, _states = model.predict(obs)
            print("Buy Signal:", "Yes" if action == 1 else "No")
        else:
            print("Error: strategy_chart_data is not a DataFrame")
    else:
        print("Error: chart_data is None")

if __name__ == "__main__":
    main()
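To run this post's code standalone, without the part-1 helpers, hypothetical stubs along these lines can stand in if defined before main() is called. All three bodies below are invented placeholders, not the part-1 implementations:

# Hypothetical stubs for the part-1 helpers -- invented placeholders only.
def save_db_market_infos(ticker, chart_intervals, current_time):
    # Fake 5-minute candle data instead of hitting the exchange/DB.
    n = 500
    times = pd.date_range(end=current_time, periods=n, freq="5min")
    close = 100 + np.cumsum(np.random.randn(n))
    return pd.DataFrame({"time": times, "close": close, "low": close - 0.5,
                         "m_feat1": np.random.rand(n)})

def get_strategy_mst_data():
    return {}  # placeholder: strategy master data is unused in this post

def calculate_indicators(chart_data, ticker):
    # Placeholder: return the latest row's m_* features as a record.
    return chart_data.filter(regex='^m_').iloc[-1].to_dict()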