Machine Learning Tutorial for Beginners — PyTorch MNIST: Train, Save and Deploy

Jiuhe Wang
3 min read · Apr 17, 2022

Hello, welcome to Machine Learning Tutorial for Beginners.

In today’s tutorial, we’ll learn how to train an MNIST model using the official PyTorch example script, save the model, and serve it using Pinferencia.

Never heard of Pinferencia? It’s not too late. Check it out on GitHub.

Train and Save the model

Visit the MNIST example in PyTorch Examples and download the files.

Run the commands below to install the dependencies and train the model:

After training finishes, a mnist_cnn.pt file is created in the folder, next to main.py. (The script only saves this file when it is run with its --save-model flag.)

The model architecture is defined in main.py.

Load the model

Now let’s load the model:

Because the .pt file is just a state dict, we need to initialize the model first and then load the state dict into it.
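
The "initialize first, then load" pattern can be sketched in a self-contained way. This is only an illustration under stated assumptions: TinyNet is a hypothetical stand-in for the Net class in main.py, and tiny.pt is a throwaway file, so the snippet runs without the trained MNIST weights. It assumes PyTorch is installed.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the Net class defined in main.py."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Save only the weights (the state dict), not the model object itself.
trained = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tiny.pt")
torch.save(trained.state_dict(), path)

# The file holds no architecture, so build the model first, then load into it.
restored = TinyNet()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch to inference mode before predicting

# Both models now hold identical weights, so their outputs match.
x = torch.ones(1, 4)
same_output = torch.equal(trained(x), restored(x))
```

For the tutorial’s model, the same two steps would presumably be `from main import Net`, then `model = Net()` followed by `model.load_state_dict(torch.load("mnist_cnn.pt"))`, assuming main.py and mnist_cnn.pt are in the working directory.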

Serve it Through REST API

Without deployment, how could a machine learning tutorial be complete?

First, let’s install Pinferencia.

pip install "pinferencia[uvicorn]"

If you haven’t heard of Pinferencia, visit its GitHub page or its homepage to check it out. It’s an amazing library that helps you deploy your model with ease.

Let’s create a file app.py in the same folder.

Run the service, and wait for it to load the model and start the server:

uvicorn app:service --reload

Test the service:

Let’s use a base64-encoded image. You can check out the image at Best Online Base64 to Image Decoder / Converter (codebeautify.org).
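
If you would rather inspect the image locally than in an online decoder, a minimal sketch using only Python’s standard library could look like this. B64_PREFIX holds just the first 28 characters of the string sent in the request further down, which is enough to confirm the data is a JPEG; the helper names and the digit.jpg file name are hypothetical.

```python
import base64

# First 28 characters of the tutorial's image string; use the full value of
# the "data" field instead to recover the whole image.
B64_PREFIX = "/9j/4AAQSkZJRgABAQAAAQABAAD/"


def decode_image(b64_string: str) -> bytes:
    """Decode a base64 string into raw image bytes."""
    return base64.b64decode(b64_string)


def looks_like_jpeg(raw: bytes) -> bool:
    """JPEG files begin with the bytes FF D8 FF."""
    return raw[:3] == b"\xff\xd8\xff"


raw = decode_image(B64_PREFIX)
is_jpeg = looks_like_jpeg(raw)

# To save the full image, decode the complete string and write it out:
# with open("digit.jpg", "wb") as f:  # digit.jpg is a hypothetical file name
#     f.write(decode_image(full_string))
```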

The image’s base64-encoded string is the data value in the request below.

Using Python requests:

Save the code as test.py:

import requests

response = requests.post(
url="http://localhost:8000/v1/models/mnist/predict",
json={"data": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+uhfwXqy2Ph25VYnPiB3SzhUkPlXCfNkAAEsCCCeOeKx9RsLjStUu9Ou1C3NpM8Eqg5AdSVIz35FVqK9xl0HXhb/C20sdMubjTLMQXs11AhkRXmmDsCwzgAYPpz+XI/GrSLrTfiVqNzPapbw3xE8AWQNvUAKXOOmWVjg+teeUV2fgXxd4hsPE2hWEGuX8Vh9uhja3Fw3lbGcBhtzjGCad8XI7iL4p68twHDGcMm45+QqCuPbBFcVRRU97fXepXb3d9dT3VzJjfNPIXdsAAZY8nAAH4VBX/9k="},
)
print("Prediction:", response.json())

Run python test.py; it prints the JSON response returned by the service.

Cool! But wait, it gets even cooler:

You can use the Swagger UI at http://127.0.0.1:8000 (the server’s address) to try the prediction.
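
To try the service with your own image, the request body needs the same shape as the one in test.py: a JSON object with the base64-encoded image under "data". The sketch below builds such a body with the standard library only. The build_payload helper and the digit.jpg file name are hypothetical, and it is an assumption that the service expects a 28×28 grayscale digit like the MNIST training images.

```python
import base64
import json


def build_payload(image_bytes: bytes) -> str:
    """Return a JSON request body with the image base64-encoded under "data"."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return json.dumps({"data": encoded})


# Stand-in bytes so the sketch runs on its own; in practice read a real file,
# e.g. open("digit.jpg", "rb").read() (digit.jpg is a hypothetical name).
sample_bytes = b"\xff\xd8\xff\xe0 not a real image"
body = build_payload(sample_bytes)
```

With requests, the equivalent is to pass the dictionary directly via the json= argument, as test.py does.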

Extra bonus: Sum Up the MNIST Images
