Getting Started with Trainwave

This guide will walk you through setting up your first machine learning training job on Trainwave. We’ll use a simple PyTorch example, but you can adapt these steps for any ML framework.

Prerequisites

  • Python 3.10 or later
  • pip package manager
  • A Trainwave account (Sign up here)
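
The first two prerequisites can be checked from Python before installing anything. This is a minimal sketch; the helper names `meets_minimum` and `check_prerequisites` are illustrative and not part of Trainwave.

```python
import shutil
import sys

MIN_PYTHON = (3, 10)  # minimum version listed in the prerequisites


def meets_minimum(version, minimum=MIN_PYTHON):
    """Return True if a (major, minor, ...) version tuple satisfies the minimum."""
    return tuple(version[:2]) >= minimum


def check_prerequisites():
    """Return a list of human-readable problems; an empty list means none were found."""
    problems = []
    if not meets_minimum(sys.version_info):
        found = f"{sys.version_info.major}.{sys.version_info.minor}"
        problems.append(f"Python {found} is older than 3.10")
    if shutil.which("pip") is None and shutil.which("pip3") is None:
        problems.append("pip was not found on PATH")
    return problems


if __name__ == "__main__":
    for problem in check_prerequisites():
        print(problem)
```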

Step 1: Install the CLI

The Trainwave CLI is your primary tool for managing training jobs. Install it using pip:

pip install trainwave-cli

Step 2: Authenticate

Log in to your Trainwave account through the CLI:

wave auth login

This will open your browser for authentication. Alternatively, you can generate an API token in the dashboard and set it directly:

wave auth set-token <your-api-token>

Verify your authentication:

wave auth whoami
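
In a script or CI pipeline, the same check can be run from Python. The sketch below assumes that `wave auth whoami` exits with a non-zero status when no valid credentials are stored; that behavior is not stated in this guide, so confirm it against your installed CLI version. The function name `is_authenticated` is illustrative.

```python
import shutil
import subprocess


def is_authenticated(command=("wave", "auth", "whoami")):
    """Run the whoami command and report whether it exited successfully.

    Assumption: the CLI returns a non-zero exit status when the user
    is not logged in.
    """
    if shutil.which(command[0]) is None:
        return False  # the CLI is not installed or not on PATH
    result = subprocess.run(list(command), capture_output=True, text=True)
    return result.returncode == 0


if __name__ == "__main__":
    print("authenticated" if is_authenticated() else "not authenticated")
```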

Step 3: Create an Organization

Organizations help you manage projects and billing. Create one through the web interface.

Example organization structure:

MyOrg
├── Project A (ML Research)
├── Project B (Production Models)
└── Project C (Experiments)

Step 4: Create a Project

Create a new project in your organization:

  1. Go to the Projects dashboard
  2. Click “New Project”
  3. Provide a name for your project
  4. Note your project ID — you’ll need it in your configuration

Step 5: Configure Your Training Job

Create a trainwave.toml in your project directory. Here’s a complete example for a PyTorch training job:

# Basic Configuration
name = "pytorch-mnist"
project = "p-your-project-id"
description = "Training MNIST classifier using PyTorch"

# Resource Configuration
gpu_type = "RTX A5000"
gpus = 1
cpu_cores = 4
memory_gb = 16
hdd_size_mb = 51200

# Runtime Configuration
image = "trainwave/pytorch:2.3.1"
setup_command = """
pip install -r requirements.txt
wandb login ${WANDB_API_KEY}
"""
run_command = "python train.py"

# Optional Settings (top-level keys must appear before any [table] header)
expires = "4h"

# Environment Variables
[env_vars]
WANDB_API_KEY = "${WANDB_API_KEY}"
PYTORCH_CUDA_ALLOC_CONF = "max_split_size_mb:512"
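
Before launching, it can help to parse the file locally and confirm that each setting landed where you expect. In TOML, every key written after a `[env_vars]` header belongs to that table, so a top-level setting such as `expires` placed below the header is read as an environment variable instead. The sketch below uses the standard-library `tomllib` parser. The list of expected keys is a hypothetical checklist taken from the example, not an official Trainwave schema.

```python
try:
    import tomllib  # standard library from Python 3.11
except ModuleNotFoundError:
    import tomli as tomllib  # assumption: `pip install tomli` on Python 3.10

# Hypothetical checklist based on the example config, not an official schema.
EXPECTED_TOP_LEVEL = ("name", "project", "gpu_type", "gpus", "image", "run_command")


def missing_keys(toml_text, expected=EXPECTED_TOP_LEVEL):
    """Return the expected keys that are absent from the top level of the file."""
    config = tomllib.loads(toml_text)
    return [key for key in expected if key not in config]


def misplaced_keys(toml_text, table="env_vars", suspects=("expires",)):
    """Return suspect keys that were parsed as members of the given table."""
    config = tomllib.loads(toml_text)
    return [key for key in suspects if key in config.get(table, {})]
```

For example, `misplaced_keys(open("trainwave.toml").read())` returns `["expires"]` when `expires` was written below the `[env_vars]` header.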

Step 6: Prepare Your Code

Here’s a minimal example of a PyTorch training script (train.py):

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import wandb

# Initialize wandb
wandb.init(project="mnist-example")


# Define model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        # 28x28 input -> 26x26 (conv1) -> 24x24 (conv2) -> 12x12 (pool),
        # so the flattened size is 64 * 12 * 12 = 9216
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


# Setup training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Load data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=64)

# Training loop
for epoch in range(10):
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Log the loss every 100 batches
        if batch_idx % 100 == 0:
            wandb.log({
                "loss": loss.item(),
                "epoch": epoch
            })

# Save model
torch.save(model.state_dict(), "model.pt")
wandb.save("model.pt")
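
The input size of `fc1` has to equal the number of values left after the convolution and pooling layers are flattened. The sketch below works through that arithmetic for 28×28 MNIST images and the layers in `Net` (two 3×3 convolutions without padding, then a 2×2 max pool). The helper names are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(image_size=28):
    """Trace a square input through conv1 -> conv2 -> max_pool2d(2) -> flatten."""
    size = conv_out(image_size, kernel=3)      # conv1: 28 -> 26
    size = conv_out(size, kernel=3)            # conv2: 26 -> 24
    size = conv_out(size, kernel=2, stride=2)  # max_pool2d(x, 2): 24 -> 12
    channels = 64                              # out_channels of conv2
    return channels * size * size


print(flattened_features())  # 9216
```

If you change a kernel size, add padding, or insert another pooling layer, rerun the calculation and update the first argument of `fc1` to match.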

And the corresponding requirements.txt:

torch>=2.0.0
torchvision>=0.15.0
wandb>=0.15.0

Step 7: Launch Your Job

Launch your training job from the directory containing your trainwave.toml:

wave jobs launch

Monitor your job:

# View job status
wave jobs status

# Stream logs
wave jobs logs -f
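
To react to a specific log line without watching the terminal, the log stream can be followed from Python. This is a minimal sketch that wraps any long-running command. The function name `stream_until` and the choice of marker are illustrative, and nothing here relies on a particular Trainwave log format.

```python
import subprocess


def stream_until(command, marker):
    """Print lines from a long-running command until one contains `marker`.

    Returns True if the marker was seen, False if the stream ended first.
    """
    with subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    ) as proc:
        for line in proc.stdout:
            print(line, end="")
            if marker in line:
                proc.terminate()  # stop following once the marker appears
                return True
    return False
```

For example, `stream_until(["wave", "jobs", "logs", "-f"], "Traceback")` follows the logs and returns as soon as a Python traceback shows up in the output of train.py.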

Common Issues and Solutions

GPU Not Detected

If your code can’t detect the GPU, run this check inside the job to see what PyTorch reports:

import torch

print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
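
If PyTorch reports that CUDA is unavailable, the environment often explains why. The sketch below collects two general CUDA hints using only the standard library: whether `CUDA_VISIBLE_DEVICES` is set to an empty string (which hides every GPU) and whether `nvidia-smi` is on the PATH. The function name is illustrative, and nothing here is specific to Trainwave.

```python
import os
import shutil


def gpu_environment_report(environ=None):
    """Collect environment hints that can explain why CUDA is unavailable."""
    environ = os.environ if environ is None else environ
    visible = environ.get("CUDA_VISIBLE_DEVICES")
    return {
        # None means the variable is unset, so no GPU is hidden by it
        "cuda_visible_devices": visible,
        # An empty string hides every GPU from CUDA applications
        "gpus_hidden_by_env": visible is not None and visible.strip() == "",
        # nvidia-smi ships with the NVIDIA driver
        "nvidia_smi_found": shutil.which("nvidia-smi") is not None,
    }


if __name__ == "__main__":
    for key, value in gpu_environment_report().items():
        print(f"{key}: {value}")
```

It is also worth printing `torch.version.cuda`, which is `None` when a CPU-only build of PyTorch is installed.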

Out of Memory

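If the job runs out of GPU memory, capping the size of blocks the PyTorch caching allocator may split can reduce fragmentation. Add this to your trainwave.toml: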

env_vars.PYTORCH_CUDA_ALLOC_CONF = "max_split_size_mb:512"
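
When running train.py somewhere that does not read trainwave.toml, the same setting can be applied from Python, provided it happens before PyTorch initializes CUDA. This is a minimal sketch; the helper name `configure_allocator` is illustrative. It leaves any value that is already set untouched.

```python
import os


def configure_allocator(max_split_size_mb=512, environ=None):
    """Set PYTORCH_CUDA_ALLOC_CONF unless it is already defined; return its value."""
    environ = os.environ if environ is None else environ
    environ.setdefault(
        "PYTORCH_CUDA_ALLOC_CONF", f"max_split_size_mb:{max_split_size_mb}"
    )
    return environ["PYTORCH_CUDA_ALLOC_CONF"]


# Call this at the top of the script, before `import torch`.
configure_allocator()
```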

Job Timeout

Extend the job timeout in trainwave.toml:

expires = "12h"

Need help? Join our Discord community or email [email protected].