A PyTorch Starter Program
2022/2/20 22:30:38
Reposted from my personal website: https://wzw21.cn/2022/02/20/hello-pytorch/
Based on the starter code from PyTorch For Audio and Music Processing, with some comments and extra content added. It covers the following steps:
- Download dataset
- Create data loader
- Build model
- Train
- Save trained model
- Load model
- Predict
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
```
```python
def download_mnist_datasets():
    train_data = datasets.MNIST(
        root="data",
        download=True,
        train=True,
        transform=ToTensor()
    )
    val_data = datasets.MNIST(
        root="data",
        download=True,
        train=False,
        transform=ToTensor()
    )
    return train_data, val_data
```
```python
class SimpleNet(nn.Module):
    def __init__(self):  # constructor
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_layers = nn.Sequential(
            nn.Linear(28 * 28, 256),  # fully connected layer (input_shape, output_shape)
            nn.ReLU(),
            nn.Linear(256, 10)
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_data):
        flattened_data = self.flatten(input_data)
        logits = self.dense_layers(flattened_data)  # "logits" here means the input of the final softmax
        predictions = self.softmax(logits)
        return predictions
```
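One caveat: `nn.CrossEntropyLoss` (used below) already combines log-softmax and negative log-likelihood internally, so it expects raw logits rather than softmax outputs. The model above still trains, but the extra softmax compresses the loss range, which is why the losses reported later hover around 1.47 to 1.57 instead of approaching zero. A more conventional variant returns the logits directly; the class name `SimpleNetLogits` below is just illustrative and not part of the original tutorial:

```python
# Variant (not in the original): return raw logits and let
# nn.CrossEntropyLoss apply the softmax internally.
class SimpleNetLogits(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_layers = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, input_data):
        return self.dense_layers(self.flatten(input_data))  # raw logits
```

With this variant, `predict` would call `torch.softmax(predictions, dim=1)` only if actual probabilities are needed; for picking the most likely class, an `argmax` over the logits is enough.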
This already requires more code than TensorFlow 2.x or Keras!
```python
def train_one_epoch(model, data_loader, loss_fn, optimizer, device):
    model.train()  # change to train mode
    loss_sum = 0.
    correct = 0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # calculate loss
        predictions = model(inputs)  # this will call the forward function automatically
        loss = loss_fn(predictions, targets)

        # backpropagate loss and update weights
        optimizer.zero_grad()  # reset grads
        loss.backward()        # calculate grads
        optimizer.step()       # update weights

        loss_sum += loss.item()  # item() returns the value of this tensor as a standard Python number
        with torch.no_grad():
            _, predictions_indexes = torch.max(predictions, 1)  # get predicted indexes
            correct += torch.sum(predictions_indexes == targets)
            # or: correct += (predictions.argmax(1) == targets).type(torch.float).sum().item()
    print(f"Train loss: {(loss_sum / len(data_loader)):.4f}, train accuracy: {(correct / len(data_loader.dataset)):.4f}")


def val_one_epoch(model, data_loader, loss_fn, device):
    model.eval()  # change to eval mode
    loss_sum = 0.
    correct = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = model(inputs)
            loss = loss_fn(predictions, targets)
            loss_sum += loss.item()
            _, predictions_indexes = torch.max(predictions, 1)
            correct += torch.sum(predictions_indexes == targets)
    print(f"Validation loss: {(loss_sum / len(data_loader)):.4f}, validation accuracy: {(correct / len(data_loader.dataset)):.4f}")


def train(model, train_data_loader, val_data_loader, loss_fn, optimizer, device, epochs):
    for i in range(epochs):
        print(f"Epoch {i+1}")
        train_one_epoch(model, train_data_loader, loss_fn, optimizer, device)
        val_one_epoch(model, val_data_loader, loss_fn, device)
        print("-----------------------")
    print("Training is done")
```
```python
def predict(model, input, target, class_mapping):
    # input's shape = torch.Size([1, 28, 28])
    model.eval()  # change to eval mode
    with torch.no_grad():  # don't have to calculate grads here
        predictions = model(input)  # predictions' shape = torch.Size([1, 10])
        predicted_index = predictions[0].argmax(0)
        predicted = class_mapping[predicted_index]
        expected = class_mapping[target]
    return predicted, expected
```
```python
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = .001

class_mapping = [
    "0", "1", "2", "3", "4",
    "5", "6", "7", "8", "9"
]
```
```
Using cuda device
```
```python
# download the MNIST dataset
train_data, val_data = download_mnist_datasets()
print("Dataset downloaded")

# create data loaders for the train and validation sets
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE)
val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
```
```
Dataset downloaded
```
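Note that `DataLoader` does not shuffle by default. For the training set it is common to pass `shuffle=True` so that each epoch sees the samples in a different order; a minimal variant (not in the original code):

```python
# Variant (not in the original): shuffle the training data every epoch
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
```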
```python
# build model
simple_net = SimpleNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(simple_net.parameters(), lr=LEARNING_RATE)

# train model
train(simple_net, train_data_loader, val_data_loader, loss_fn, optimizer, device, EPOCHS)

# save model
torch.save(simple_net.state_dict(), "simple_net.pth")
print("Model saved")

# torch.save(model.state_dict(), "my_model.pth")  # only save parameters
# torch.save(model, "my_model.pth")               # save the whole model
# checkpoint = {"net": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch}
```
```
Epoch 1
Train loss: 1.5717, train accuracy: 0.9036
Validation loss: 1.5280, validation accuracy: 0.9388
-----------------------
Epoch 2
Train loss: 1.5148, train accuracy: 0.9506
Validation loss: 1.5153, validation accuracy: 0.9507
-----------------------
Epoch 3
Train loss: 1.5008, train accuracy: 0.9629
Validation loss: 1.5016, validation accuracy: 0.9625
-----------------------
Epoch 4
Train loss: 1.4924, train accuracy: 0.9707
Validation loss: 1.4958, validation accuracy: 0.9680
-----------------------
Epoch 5
Train loss: 1.4871, train accuracy: 0.9760
Validation loss: 1.4919, validation accuracy: 0.9702
-----------------------
Epoch 6
Train loss: 1.4837, train accuracy: 0.9789
Validation loss: 1.4884, validation accuracy: 0.9742
-----------------------
Epoch 7
Train loss: 1.4811, train accuracy: 0.9814
Validation loss: 1.4885, validation accuracy: 0.9736
-----------------------
Epoch 8
Train loss: 1.4787, train accuracy: 0.9837
Validation loss: 1.4896, validation accuracy: 0.9724
-----------------------
Epoch 9
Train loss: 1.4771, train accuracy: 0.9851
Validation loss: 1.4884, validation accuracy: 0.9739
-----------------------
Epoch 10
Train loss: 1.4758, train accuracy: 0.9863
Validation loss: 1.4889, validation accuracy: 0.9732
-----------------------
Training is done
Model saved
```
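The commented-out checkpoint dictionary in the cell above can be extended into a full save-and-resume routine that also restores the optimizer state. A minimal sketch (the file name `checkpoint.pth` is just illustrative):

```python
# Save a training checkpoint (sketch; extends the commented-out dict above)
checkpoint = {
    "net": simple_net.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": EPOCHS,
}
torch.save(checkpoint, "checkpoint.pth")

# Resume later: restore both the model weights and the optimizer state
checkpoint = torch.load("checkpoint.pth")
simple_net.load_state_dict(checkpoint["net"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"]
```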
```python
# load model
reloaded_simple_net = SimpleNet()
state_dict = torch.load("simple_net.pth")
reloaded_simple_net.load_state_dict(state_dict)

# make an inference
input, target = val_data[0][0], val_data[0][1]
predicted, expected = predict(reloaded_simple_net, input, target, class_mapping)
print(f"Predicted: '{predicted}', expected: '{expected}'")
```
```
Predicted: '7', expected: '7'
```
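To check the reloaded model on more than a single sample, `val_one_epoch` from above can be reused directly. A small sketch, assuming `device`, `loss_fn`, and `val_data_loader` from the earlier cells are still in scope:

```python
# Sanity check: run the reloaded model over the whole validation set
reloaded_simple_net.to(device)
val_one_epoch(reloaded_simple_net, val_data_loader, loss_fn, device)
```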