💡 LSTM股价预测实战技巧
本章介绍了深度学习在量化中的应用,这里分享LSTM模型的实战技巧:
1. 数据预处理关键
# Model log returns instead of raw prices. shift(1) has no predecessor for the
# first row, so that row's log return is NaN.
df['log_returns'] = np.log(df['close'].div(df['close'].shift(1)))

# Scaling: learn mean/std from the training split only, then reuse those same
# statistics on the test split (refitting on test data would let test-set
# statistics leak into preprocessing).
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler().fit(train_data)
train_scaled = scaler.transform(train_data)
test_scaled = scaler.transform(test_data)  # transform only, never refit
2. 序列构建
def create_sequences(data, seq_len, pred_horizon, target_col=None):
    """Build sliding-window inputs and forward-looking labels.

    Sample ``i`` uses ``data[i : i + seq_len]`` as input and
    ``data[i + seq_len + pred_horizon - 1]`` as its label, i.e. the value
    ``pred_horizon`` steps after the last input step (``pred_horizon=1`` is
    the very next step). Every window whose label index lies inside ``data``
    is produced, including the final one.

    Args:
        data: array-like of shape ``(T,)`` or ``(T, n_features)``. It is
            converted with ``np.asarray`` so lists and pandas objects are
            indexed by position.
        seq_len: input window length, must be >= 1.
        pred_horizon: label offset after the window, must be >= 1.
        target_col: optional column index for 2-D ``data``. When given, the
            label is that single column instead of the whole feature row.

    Returns:
        ``(X, y)`` where ``X.shape == (n, seq_len, *data.shape[1:])`` and
        ``y`` holds the ``n`` labels, with
        ``n = max(T - seq_len - pred_horizon + 1, 0)``.

    Raises:
        ValueError: if ``seq_len`` or ``pred_horizon`` is below 1 (a horizon
            of 0 would make the label the window's own last step).
    """
    if seq_len < 1 or pred_horizon < 1:
        raise ValueError("seq_len and pred_horizon must both be >= 1")

    arr = np.asarray(data)
    label_offset = seq_len + pred_horizon - 1
    # The last valid window starts at T - seq_len - pred_horizon, so there are
    # T - seq_len - pred_horizon + 1 == T - label_offset windows in total.
    n_samples = max(len(arr) - label_offset, 0)

    # Pre-sized so an empty result still has shape (0, seq_len, ...).
    X = np.empty((n_samples, seq_len) + arr.shape[1:], dtype=arr.dtype)
    for i in range(n_samples):
        X[i] = arr[i:i + seq_len]

    # Labels of consecutive windows form one contiguous slice; copy so that y
    # does not alias the caller's data.
    y = arr[label_offset:label_offset + n_samples].copy()
    if target_col is not None:
        y = y[:, target_col]
    return X, y
# Suggested defaults: each input window covers 20 days of history, and the
# label is taken 5 days after the window ends.
seq_len, pred_horizon = 20, 5
3. 模型架构优化
class ImprovedLSTM(nn.Module):
    """LSTM encoder with attention pooling over time and a small MLP head.

    The LSTM is built with ``batch_first=True``, so ``forward`` expects
    ``x`` shaped ``(batch, seq_len, input_size)`` and returns ``(batch, 1)``.

    Args:
        input_size: number of features per time step.
        hidden_size: LSTM hidden units per direction.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked LSTM layers. PyTorch applies it only
            between layers (and warns otherwise), so it is forced to 0 when
            ``num_layers == 1``.
        fc_dropout: dropout rate inside the regression head.
        bidirectional: whether the LSTM reads each window in both directions.
    """

    def __init__(self, input_size, hidden_size=64, num_layers=2,
                 dropout=0.2, fc_dropout=0.3, bidirectional=True):
        super().__init__()
        num_directions = 2 if bidirectional else 1
        lstm_out_size = hidden_size * num_directions
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidirectional,
        )
        # Attention: one scalar score per time step.
        self.attention = nn.Linear(lstm_out_size, 1)
        self.fc = nn.Sequential(
            nn.Linear(lstm_out_size, 32),
            nn.ReLU(),
            nn.Dropout(fc_dropout),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        lstm_out, _ = self.lstm(x)  # (batch, seq_len, lstm_out_size)
        # Softmax over the time axis (dim=1) so each sample's weights sum to 1.
        attn_weights = torch.softmax(self.attention(lstm_out), dim=1)
        # Attention-weighted sum over time -> (batch, lstm_out_size).
        context = torch.sum(attn_weights * lstm_out, dim=1)
        return self.fc(context)
4. 防止过拟合
# Early stopping: keep the weights from the epoch with the lowest validation
# loss, stop after `patience` consecutive epochs without improvement, then
# restore those best weights.
# NOTE(review): `epochs`, `model`, `train_one_epoch` and `validate` are
# assumed to be defined earlier in the training script.
from torch.utils.data import DataLoader
from copy import deepcopy

best_loss = float('inf')
patience = 10
min_delta = 0.0  # minimum decrease in val_loss that counts as improvement
counter = 0
# Snapshot the starting weights so `best_model` is always bound. Without this,
# load_state_dict below raises NameError when `epochs` is 0 or when val_loss
# never drops below inf (e.g. it is NaN from the first epoch).
# deepcopy because state_dict() holds references to the live parameters.
best_model = deepcopy(model.state_dict())
for epoch in range(epochs):
    train_loss = train_one_epoch()
    val_loss = validate()
    if val_loss < best_loss - min_delta:
        best_loss = val_loss
        best_model = deepcopy(model.state_dict())
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
model.load_state_dict(best_model)
5. 评估指标
def evaluate_predictions(predictions, actuals, target="returns"):
    """Print and return point-error and direction-accuracy metrics.

    Both inputs are flattened first, so a model output shaped ``(N, 1)`` can
    be compared with actuals shaped ``(N,)``. Without flattening, numpy would
    broadcast them to an ``(N, N)`` matrix and silently produce a wrong
    MSE/MAE.

    Args:
        predictions: forecasts, any array-like that flattens to length N.
        actuals: realized values, any array-like that flattens to length N.
        target: ``"returns"`` (default) when both inputs are returns; a
            direction is then simply the sign of each value. ``"prices"``
            when both are price levels; the predicted direction at step t is
            then ``predictions[t] - actuals[t-1]``, the move relative to the
            last *known* price, compared with ``actuals[t] - actuals[t-1]``.

    Returns:
        dict with keys ``'mse'``, ``'mae'`` and ``'direction_acc'``.

    Raises:
        ValueError: on empty or mismatched inputs, on fewer than 2 points when
            ``target="prices"``, or on an unknown ``target``.
    """
    pred = np.asarray(predictions, dtype=float).ravel()
    act = np.asarray(actuals, dtype=float).ravel()
    if pred.shape != act.shape:
        raise ValueError(
            f"predictions and actuals differ in length: {pred.size} vs {act.size}"
        )
    if pred.size == 0:
        raise ValueError("predictions and actuals must not be empty")

    # Point-forecast error
    err = pred - act
    mse = np.mean(err ** 2)
    mae = np.mean(np.abs(err))

    # Direction accuracy (often the more useful metric)
    if target == "returns":
        pred_direction = pred > 0
        actual_direction = act > 0
    elif target == "prices":
        if pred.size < 2:
            raise ValueError("target='prices' needs at least 2 points")
        prev_actual = act[:-1]
        pred_direction = (pred[1:] - prev_actual) > 0
        actual_direction = (act[1:] - prev_actual) > 0
    else:
        raise ValueError(f"unknown target {target!r}; use 'returns' or 'prices'")
    direction_accuracy = np.mean(pred_direction == actual_direction)

    print(f"MSE: {mse:.6f}")
    print(f"MAE: {mae:.6f}")
    print(f"方向准确率: {direction_accuracy:.2%}")
    return {'mse': mse, 'mae': mae, 'direction_acc': direction_accuracy}
核心建议:
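- 价格预测很难，方向预测更实用
- 方向准确率稳定高于55%才可能有价值，前提是：该结果在样本外（按时间顺序切分的测试集）上测得；明显高于“始终预测上涨”等简单基准（上涨日占比本身就可能超过50%）；并且扣除交易成本后仍能盈利
- 结合其他指标（如RSI、MACD）作为额外输入特征通常效果更好，但同样需要在样本外验证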