新手报到
import torchimport torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import RobustScaler
# 1. 增强版数据预处理
def load_and_preprocess_data():
try:
# Load the data
data = pd.read_csv('CS2_35.csv', na_values=[' ', 'NA', 'N/A', 'NaN'])
# Ensure required columns exist
required_cols = ['cycle', 'resistance', 'CCCT', 'CVCT', 'SoH']
if not all(col in data.columns for col in required_cols):
raise ValueError("CSV file is missing required columns.")
# Filter and clean data
data = data.dropna()
# Scale features and target
x_scaler = RobustScaler()
y_scaler = RobustScaler()
X = x_scaler.fit_transform(data[['resistance', 'CCCT', 'CVCT']])
y = y_scaler.fit_transform(data[['SoH']])
return X, y, x_scaler, y_scaler
except Exception as e:
print(f"Error loading or preprocessing data: {e}")
raise# Re-raise the exception to stop execution
# 2. 改进的模型架构
class EnhancedModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_size, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128, 64),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(64, 32),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.Linear(32, 1)
)
def forward(self, x):
return self.net(x)
# 3. 优化的训练流程
def train_and_evaluate():
# 加载数据
X, y, x_scaler, y_scaler = load_and_preprocess_data()
# 数据集划分(分层抽样)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, shuffle=True)
# 转换为Tensor
X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.FloatTensor(y_train)
y_test = torch.FloatTensor(y_test)
# 创建DataLoader
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# 初始化模型
model = EnhancedModel(input_size=X.shape[1])
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# 训练循环
print("=== 开始训练 ===")
best_r2 = -np.inf
patience = 30
patience_counter = 0
for epoch in range(500):
model.train()
epoch_loss = 0
for batch_x, batch_y in train_loader:
optimizer.zero_grad()
outputs = model(batch_x)
loss = criterion(outputs, batch_y)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
scheduler.step()
# 验证
model.eval()
with torch.no_grad():
y_pred = model(X_test)
y_test_orig = y_scaler.inverse_transform(y_test.numpy())
y_pred_orig = y_scaler.inverse_transform(y_pred.numpy())
current_r2 = r2_score(y_test_orig, y_pred_orig)
# 早停机制基于R²
if current_r2 > best_r2:
best_r2 = current_r2
patience_counter = 0
torch.save(model.state_dict(), 'best_model.pth')
else:
patience_counter += 1
if patience_counter >= patience:
print(f"Early stopping at epoch {epoch}")
break
if (epoch + 1) % 20 == 0:
print(f"Epoch {epoch + 1:3d} | Loss: {epoch_loss / len(train_loader):.4f} | R²: {current_r2:.4f}")
# 加载最佳模型
model.load_state_dict(torch.load('best_model.pth'))
# 最终评估
model.eval()
with torch.no_grad():
y_pred = model(X_test)
y_test_orig = y_scaler.inverse_transform(y_test.numpy())
y_pred_orig = y_scaler.inverse_transform(y_pred.numpy())
r2 = r2_score(y_test_orig, y_pred_orig)
mae = mean_absolute_error(y_test_orig, y_pred_orig)
rmse = np.sqrt(mean_squared_error(y_test_orig, y_pred_orig))
mape = np.mean(np.abs((y_test_orig - y_pred_orig) / y_test_orig)) * 100
print("\n=== 最终测试结果 ===")
print(f"MAE: {mae:.4f}")
print(f"MAPE: {mape:.4f}%")
print(f"RMSE: {rmse:.4f}")
print(f"R² Score: {r2:.4f}")
# 可视化
plt.figure(figsize=(15, 5))
# 预测 vs 真实
plt.subplot(1, 3, 1)
plt.scatter(y_test_orig, y_pred_orig, alpha=0.5)
plt.plot(,
, 'r--')
plt.xlabel('True SoH')
plt.ylabel('Predicted SoH')
plt.title(f'Prediction (R²={r2:.2f})')
# 趋势对比
plt.subplot(1, 3, 2)
plt.plot(y_test_orig, label='True')
plt.plot(y_pred_orig, label='Predicted')
plt.xlabel('Sample Index')
plt.ylabel('SoH')
plt.legend()
plt.title('Trend Comparison')
# 误差分布
plt.subplot(1, 3, 3)
errors = y_test_orig - y_pred_orig
plt.hist(errors, bins=30)
plt.xlabel('Prediction Error')
plt.ylabel('Frequency')
plt.title('Error Distribution')
plt.tight_layout()
plt.savefig('results.png', dpi=300)
plt.show()
# Create a DataFrame with true and predicted values
results_df = pd.DataFrame({
'True_SoH': y_test_orig.flatten(),
'Predicted_SoH': y_pred_orig.flatten()
})
# Save to CSV
results_df.to_csv('LSTM_predictions.csv', index=False)
print("Predictions saved to LSTM_predictions.csv")
if __name__ == "__main__":
train_and_evaluate()
http://bbs.52pcgame.net/data/attachment/album/201809/09/194700b4ocjybubd1eo4k6.gif
欢迎新人,请勿发生以下行为,否則将受到扣分,禁言以致封号的惩罚:
1.广告:不能确定自己发布的链接是否为广告的请提前联系当区版主或者站点管理员。
2.无意义回复:纯数字,纯字母 (例: 666666666、ASDFGHJK等类似的内容),纯表情,连续回帖三连以上,或者不同帖子大量复制粘贴一样的回复等不尊重发帖作者的行为均可能会被版主定义为无意义回复,请注意不同版主的执法尺度。
3.严禁调侃时政:例:严禁注册 我国领导人名作为ID,禁用我国当代领导人头像(禁止“各类暴力膜”), 禁谈政治、时事, 各区版规都有所不同,详细请到版区阅读。(ps:实际上各区都严禁谈论政治,只是某些区问题较突出,新人们要特别注意)
4.人身攻击: 严禁发表对会员进行人身攻击、谩骂、挑衅等言论。
5.法律法规:请自觉遵守《全国人大常委会关于维护互联网安全的决定》、《互联网信息服务管理办法》、《互联网电子公告服务管理规定》及中华人民共和国其他各项有关法律法规。
以上就是最基本的原则,当然新人最好还是先看一下条例 【会员严禁发表与回复的内容】{:4_189:}
相信新人们都会好好遵守的,有意见可到或者老账号找回问题请到站务办公室反馈{:4_136:}。
有意义,怎么会没有意义呢 {:4_97:}{:4_97:} 我错了,球球手下留情吧 {:4_104:}{:4_100:}{:4_146:}{:4_170:} {:4_104:}{:4_145:}{:4_182:}{:4_182:}{:4_159:}{:4_160:}{:4_184:}
页:
[1]