PyTorch入门程序

2022/2/20 22:30:38

本文主要是介绍PyTorch入门程序,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

转载自我的个人网站 https://wzw21.cn/2022/02/20/hello-pytorch/

在 PyTorch For Audio and Music Processing 入门代码的基础上添加了一些注释和新的内容

  1. Download dataset
  2. Create data loader
  3. Build model
  4. Train
  5. Save trained model
  6. Load model
  7. Predict
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
def download_mnist_datasets():
  """Download (if not cached) and return the MNIST train and validation splits.

  Both splits are stored under ``data/`` and converted to tensors by
  ``ToTensor``.

  Returns:
    tuple: ``(train_data, val_data)`` -- the MNIST ``train=True`` split
    followed by the ``train=False`` split, which this script uses for
    validation.
  """
  def _load_split(is_train):
    # Each split gets its own ToTensor transform instance.
    return datasets.MNIST(
        root="data",
        train=is_train,
        download=True,
        transform=ToTensor(),
    )

  train_split = _load_split(True)
  val_split = _load_split(False)
  return train_split, val_split
class SimpleNet(nn.Module):
  """Two-layer fully connected classifier for 28x28 single-channel images.

  ``forward`` returns raw, unnormalized class scores (logits) of shape
  ``(batch, 10)``. ``nn.CrossEntropyLoss`` applies log-softmax internally,
  so passing it softmax probabilities instead of logits squashes the
  gradients and inflates the reported loss. When probabilities are needed,
  apply ``self.softmax`` to the output. ``argmax`` over logits and over
  probabilities selects the same class.
  """

  def __init__(self): # constructor
    super().__init__()
    self.flatten = nn.Flatten()
    self.dense_layers = nn.Sequential(
        nn.Linear(28*28, 256),  # fully connected layer (in_features, out_features)
        nn.ReLU(),
        nn.Linear(256, 10),
    )
    # Kept for callers that want probabilities; intentionally NOT applied in
    # forward() because the training loss expects logits.
    self.softmax = nn.Softmax(dim=1)

  def forward(self, input_data):
    """Return class logits for ``input_data``.

    Args:
      input_data: image batch. Every dimension after the first is flattened,
        so each sample must contain 28*28 = 784 values (e.g. MNIST tensors
        of shape ``(batch, 1, 28, 28)``).

    Returns:
      Tensor of shape ``(batch, 10)`` with unnormalized scores (logits).
    """
    flattened_data = self.flatten(input_data)
    return self.dense_layers(flattened_data)

PyTorch needs more code here than TensorFlow 2.x or Keras would!

def train_one_epoch(model, data_loader, loss_fn, optimizer, device):
  """Train ``model`` for one pass over ``data_loader``; print and return metrics.

  Args:
    model: network producing ``(batch, num_classes)`` scores for ``loss_fn``.
    data_loader: iterable of ``(inputs, targets)`` batches, where targets are
      class indices.
    loss_fn: loss module; assumed to use ``reduction="mean"`` (the
      ``nn.CrossEntropyLoss`` default) so it can be re-weighted per sample.
    optimizer: optimizer over ``model``'s parameters.
    device: device the batches are moved to before the forward pass.

  Returns:
    tuple: ``(avg_loss, accuracy)`` averaged over every sample seen; both are
    ``nan`` if the loader yields no samples.
  """
  model.train()  # training mode (matters for layers such as dropout / batch-norm)
  loss_sum = 0.0
  correct = 0
  seen = 0
  for inputs, targets in data_loader:
    inputs, targets = inputs.to(device), targets.to(device)

    # forward pass: calling the module dispatches to forward()
    predictions = model(inputs)
    loss = loss_fn(predictions, targets)

    # backpropagate and update weights
    optimizer.zero_grad()  # gradients accumulate across backward() calls
    loss.backward()
    optimizer.step()

    batch_size = targets.size(0)
    # Weight each batch mean by its size so a short final batch does not
    # skew the epoch average.
    loss_sum += loss.item() * batch_size
    # argmax yields an integer tensor with no autograd history.
    correct += (predictions.argmax(dim=1) == targets).sum().item()
    seen += batch_size

  avg_loss = loss_sum / seen if seen else float("nan")
  accuracy = correct / seen if seen else float("nan")
  print(f"Train loss: {avg_loss:.4f}, train accuracy: {accuracy:.4f}")
  return avg_loss, accuracy

def val_one_epoch(model, data_loader, loss_fn, device):
  """Evaluate ``model`` on ``data_loader`` without updating it; print and return metrics.

  Args:
    model: network producing ``(batch, num_classes)`` scores for ``loss_fn``.
    data_loader: iterable of ``(inputs, targets)`` batches, where targets are
      class indices.
    loss_fn: loss module; assumed to use ``reduction="mean"`` (the
      ``nn.CrossEntropyLoss`` default) so it can be re-weighted per sample.
    device: device the batches are moved to before the forward pass.

  Returns:
    tuple: ``(avg_loss, accuracy)`` averaged over every sample seen; both are
    ``nan`` if the loader yields no samples.
  """
  model.eval()  # eval mode (matters for layers such as dropout / batch-norm)
  loss_sum = 0.0
  correct = 0
  seen = 0
  with torch.no_grad():  # no gradients needed for evaluation
    for inputs, targets in data_loader:
      inputs, targets = inputs.to(device), targets.to(device)

      predictions = model(inputs)
      batch_size = targets.size(0)
      # Weight each batch mean by its size so a short final batch does not
      # skew the average.
      loss_sum += loss_fn(predictions, targets).item() * batch_size
      correct += (predictions.argmax(dim=1) == targets).sum().item()
      seen += batch_size

  avg_loss = loss_sum / seen if seen else float("nan")
  accuracy = correct / seen if seen else float("nan")
  print(f"Validation loss: {avg_loss:.4f}, validation accuracy: {accuracy:.4f}")
  return avg_loss, accuracy

def train(model, train_data_loader, val_data_loader, loss_fn, optimizer, device, epochs):
  """Run ``epochs`` rounds of training, validating after each round.

  Prints an ``Epoch N`` header (1-based) before each round and a separator
  line after it, then a final completion message. The per-round metrics are
  printed by ``train_one_epoch`` and ``val_one_epoch``.
  """
  for epoch_number in range(1, epochs + 1):
    print(f"Epoch {epoch_number}")
    # one optimisation pass, then an evaluation pass on held-out data
    train_one_epoch(model, train_data_loader, loss_fn, optimizer, device)
    val_one_epoch(model, val_data_loader, loss_fn, device)
    print("-----------------------")
  print("Training is done")
def predict(model, input, target, class_mapping):
  """Classify one sample and map both prediction and label to class names.

  Args:
    model: classifier returning ``(1, num_classes)`` scores for ``input``.
      It is switched to eval mode and left in eval mode.
    input: a single sample. For MNIST via ``ToTensor`` this has shape
      ``(1, 28, 28)``, and its leading 1 acts as the batch dimension for a
      model that flattens from dim 1. (The parameter name shadows the builtin
      ``input`` but is kept so keyword callers keep working.)
    target: class index of the true label.
    class_mapping: sequence mapping class index -> class name.

  Returns:
    tuple: ``(predicted, expected)`` class names.
  """
  model.eval()
  # Run on the device holding the model's parameters so that a GPU model
  # also works with a CPU sample. Parameter-less models use the input as-is.
  first_param = next(model.parameters(), None)
  if first_param is not None:
    input = input.to(first_param.device)
  with torch.no_grad():  # inference only; no gradients needed
    predictions = model(input)  # expected shape: (1, num_classes)
    predicted_index = int(predictions[0].argmax(0))
  predicted = class_mapping[predicted_index]
  expected = class_mapping[target]
  return predicted, expected
# Prefer the GPU whenever PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Training hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-3

# Class index -> label name; each MNIST digit is labelled by its own string.
class_mapping = [str(digit) for digit in range(10)]
Using cuda device
# Download (or reuse the cached copy of) the MNIST train and validation splits.
train_data, val_data = download_mnist_datasets()
print("Dataset downloaded")

# Wrap both splits in data loaders. The training loader reshuffles every
# epoch so mini-batches are not drawn in one fixed order for the whole run.
# Validation metrics do not depend on order, so that loader stays unshuffled.
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
Dataset downloaded
# Build the model on the chosen device, together with its loss and optimizer.
simple_net = SimpleNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(simple_net.parameters(), lr=LEARNING_RATE)

# Train, running a validation pass after every epoch.
train(
    model=simple_net,
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    epochs=EPOCHS,
)

# Persist only the learned parameters (the state_dict).
# Other options:
#   torch.save(model, "my_model.pth")  # pickle the whole model object
#   checkpoint = {"net": model.state_dict(),
#                 "optimizer": optimizer.state_dict(),
#                 "epoch": epoch}      # enough to resume training later
torch.save(simple_net.state_dict(), "simple_net.pth")
print("Model saved")
Epoch 1
Train loss: 1.5717, train accuracy: 0.9036
Validation loss: 1.5280, validation accuracy: 0.9388
-----------------------
Epoch 2
Train loss: 1.5148, train accuracy: 0.9506
Validation loss: 1.5153, validation accuracy: 0.9507
-----------------------
Epoch 3
Train loss: 1.5008, train accuracy: 0.9629
Validation loss: 1.5016, validation accuracy: 0.9625
-----------------------
Epoch 4
Train loss: 1.4924, train accuracy: 0.9707
Validation loss: 1.4958, validation accuracy: 0.9680
-----------------------
Epoch 5
Train loss: 1.4871, train accuracy: 0.9760
Validation loss: 1.4919, validation accuracy: 0.9702
-----------------------
Epoch 6
Train loss: 1.4837, train accuracy: 0.9789
Validation loss: 1.4884, validation accuracy: 0.9742
-----------------------
Epoch 7
Train loss: 1.4811, train accuracy: 0.9814
Validation loss: 1.4885, validation accuracy: 0.9736
-----------------------
Epoch 8
Train loss: 1.4787, train accuracy: 0.9837
Validation loss: 1.4896, validation accuracy: 0.9724
-----------------------
Epoch 9
Train loss: 1.4771, train accuracy: 0.9851
Validation loss: 1.4884, validation accuracy: 0.9739
-----------------------
Epoch 10
Train loss: 1.4758, train accuracy: 0.9863
Validation loss: 1.4889, validation accuracy: 0.9732
-----------------------
Training is done
Model saved
# Rebuild the architecture on the CPU and load the saved parameters into it.
reloaded_simple_net = SimpleNet()
# map_location="cpu": the checkpoint may have been written from a CUDA model,
# and without remapping it cannot be loaded on a machine without a GPU.
state_dict = torch.load("simple_net.pth", map_location="cpu")
reloaded_simple_net.load_state_dict(state_dict)

# Classify the first validation sample. A single lookup runs the dataset
# transform once; the name `image` avoids shadowing the builtin input().
image, target = val_data[0]
predicted, expected = predict(reloaded_simple_net, image, target, class_mapping)
print(f"Predicted: '{predicted}', expected: '{expected}'")
Predicted: '7', expected: '7'


这篇关于PyTorch入门程序的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程