SoFunction
Updated on 2024-12-16

PyTorch 6: batch_train, the batch training operation

Look at the code.

import torch
import torch.utils.data as Data

torch.manual_seed(1)    # reproducible

# BATCH_SIZE = 5
BATCH_SIZE = 8      # feed 8 samples to the network per step

x = torch.linspace(1, 10, 10)       # x data (torch tensor): 1, 2, ..., 10
y = torch.linspace(10, 1, 10)       # y data (torch tensor): 10, 9, ..., 1

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini-batch size
    shuffle=False,              # keep the original order; set to True to shuffle for training
    num_workers=2,              # load data with two subprocesses
)

def show_batch():
    for epoch in range(3):   # iterate over the entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # one training step per batch
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__ == '__main__':
    show_batch()

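With BATCH_SIZE = 8, the whole dataset is used three times. The 10 samples do not divide evenly by 8, so each epoch yields one full batch of 8 followed by a final batch of the remaining 2: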

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
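
The step count per epoch in the output above follows from the dataset size and the batch size. A minimal sketch of that relationship, assuming the same 10-sample dataset; it uses num_workers=0 (an assumption made here so the snippet runs without the __main__ guard) and also shows DataLoader's drop_last option, which the original code does not use:

```python
import math
import torch
import torch.utils.data as Data

x = torch.linspace(1, 10, 10)   # 10 samples, as in the example above
y = torch.linspace(10, 1, 10)
dataset = Data.TensorDataset(x, y)

# num_workers=0 loads data in the main process
loader = Data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
batch_sizes = [len(batch_x) for batch_x, _ in loader]
steps_per_epoch = len(loader)   # number of batches per epoch: ceil(10 / 8)

# drop_last=True discards the incomplete final batch instead of yielding it
full_only = Data.DataLoader(dataset, batch_size=8, shuffle=False, drop_last=True)
full_batch_sizes = [len(batch_x) for batch_x, _ in full_only]

print(batch_sizes, steps_per_epoch, full_batch_sizes)
```

Whether to keep or drop the short final batch depends on the model; keeping it (the default) uses every sample in each epoch.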

Addendum: a PyTorch batch training error

Problem Description:

When batch training a neural network in PyTorch, the following error is sometimes raised:

TypeError: batch must contain tensors, numbers, dicts or lists; found <class ''>

Solution:

Step one:

Check the dataset inputs (this is the key point):

train_dataset = Data.TensorDataset(train_x, train_y)

Check the types of train_x and train_y: both must be tensors. My first error occurred because I passed in a Variable.

The data can be converted to a tensor like this:

train_x = torch.FloatTensor(train_x)
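
One way to make the conversion and the type check explicit before building the dataset is sketched below. It assumes the raw data starts out as plain Python lists; raw_x and raw_y are hypothetical names and values used only for illustration, and torch.tensor is just one of several conversion calls that would work here.

```python
import torch
import torch.utils.data as Data

# Hypothetical raw data: 3 samples with 2 features each, plus 3 targets
raw_x = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
raw_y = [0.0, 1.0, 0.0]

# Convert to float tensors
train_x = torch.tensor(raw_x, dtype=torch.float32)
train_y = torch.tensor(raw_y, dtype=torch.float32)

# Verify the types and that both tensors hold the same number of samples
assert isinstance(train_x, torch.Tensor) and isinstance(train_y, torch.Tensor)
assert train_x.size(0) == train_y.size(0)

train_dataset = Data.TensorDataset(train_x, train_y)
sample_x, sample_y = train_dataset[0]   # first (x, y) pair
```

Failing early on a wrong type here is usually easier to debug than the TypeError raised later, when the loader tries to assemble a batch.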

Step two:

train_loader = Data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True
)

This instantiates a DataLoader object.

Step Three:

    for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

With this in place, batch training works.

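One thing to note: train_loader yields tensors. In PyTorch versions before 0.4 these had to be wrapped in Variable before being fed to the network, as in Step Three; from 0.4 onward Variable is merged into Tensor, so the wrapping is no longer required.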
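
As a rough sketch, the loop from Step Three might look like this on PyTorch 0.4 or later, where tensors track gradients themselves and the Variable wrapper can be omitted. The toy data, the nn.Linear model, the loss, the optimizer and the learning rate are all assumptions made for illustration, not part of the original code.

```python
import torch
import torch.nn as nn
import torch.utils.data as Data

torch.manual_seed(1)

# Hypothetical toy data: y = 2x, 10 samples with one feature each
train_x = torch.linspace(1, 10, 10).unsqueeze(1)
train_y = 2 * train_x
train_loader = Data.DataLoader(
    Data.TensorDataset(train_x, train_y), batch_size=4, shuffle=True
)

model = nn.Linear(1, 1)          # hypothetical model
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

epochs = 2
steps = 0
for epoch in range(epochs):
    for step, (batch_x, batch_y) in enumerate(train_loader):
        # batch_x and batch_y are plain tensors; no Variable wrapper is used
        loss = loss_fn(model(batch_x), batch_y)
        optimizer.zero_grad()    # clear gradients from the previous step
        loss.backward()          # backpropagate
        optimizer.step()         # update the parameters
        steps += 1
```

With 10 samples and a batch size of 4, each epoch runs three steps, the last one on a batch of 2.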

The above is based on personal experience; I hope it serves as a useful reference.