Updated vimtex and snippets
parent 84a6b218fe
commit 10bf23a6ec
44 changed files with 4209 additions and 69 deletions
LuaSnip/python/torch.lua (new file, 131 lines)
@@ -0,0 +1,131 @@
local helpers = require('personal.luasnip-helper-funcs')
local get_visual = helpers.get_visual

local line_begin = require("luasnip.extras.expand_conditions").line_begin
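-- Note: the snippet constructors used below (s, fmt, fmta, i) are not
-- required explicitly; LuaSnip's Lua loader injects its snip_env into
-- snippet files loaded via require("luasnip.loaders.from_lua"), which is
-- assumed to be how this file is sourced.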
return
{
  -- COMMON IMPORTS
  s({trig="itorch"},
    fmt(
      [[
        import torch
        from torch import nn
        from torch.utils.data import Dataset, DataLoader
      ]],
      {}
    ),
    {condition = line_begin}
  ),
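  -- Note: itorch above uses fmt (curly-brace delimiters) while the snippets
  -- below use fmta (angle-bracket delimiters); since the itorch template
  -- contains no placeholder nodes, the two behave identically here.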
  -- NETWORK MODEL TEMPLATE
  s({trig="model"},
    fmta(
      [[
        class FooNet(nn.Module):
            def __init__(self):
                super(FooNet, self).__init__()
                <>

            def forward(self, x):
                <>
      ]],
      {
        i(1),
        i(2)
      }
    ),
    {condition = line_begin}
  ),
  -- CUSTOM DATASET TEMPLATE
  s({trig="dataset"},
    fmta(
      [[
        class FooDataset(Dataset):
            def __init__(self, ...):
                <>

            def __getitem__(self, index):
                # Returns the (feature vector, label) tuple at index `index`
                <>

            def __len__(self):
                # Return number of instances in dataset
                <>
      ]],
      {
        i(1),
        i(2),
        i(3)
      }
    ),
    {condition = line_begin}
  ),
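  -- The dataset template implements PyTorch's map-style Dataset protocol:
  -- __getitem__ and __len__ are the methods a DataLoader relies on to
  -- index and size the dataset.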
  -- SGD OPTIMIZER
  s({trig="optim"},
    fmta(
      [[
        optim = torch.optim.SGD(model.parameters(), lr=<>)
      ]],
      {
        i(1),
      }
    ),
    {condition = line_begin}
  ),
  -- TRAINING LOOP TEMPLATE
  s({trig="train"},
    fmta(
      [[
        def train_loop(dataloader, model, loss_fn, optim):
            N = len(dataloader.dataset)

            # Loop over all minibatches in dataset
            for mb, (X, y) in enumerate(dataloader):
                # Compute prediction and loss
                pred = model(X)
                loss = loss_fn(pred, y)

                # Backpropagation
                optim.zero_grad()
                loss.backward()
                optim.step()

                # Log loss and number of instances trained
                if mb % <> == 0:
                    loss, n = loss.item(), mb * len(X)
                    print("loss: {:.7f} [{:5d}/{:5d}]".format(loss, n, N))
      ]],
      {
        i(1, "100"),
      }
    ),
    {condition = line_begin}
  ),
  -- TEST LOOP TEMPLATE
  s({trig="test"},
    fmta(
      [[
        def test_loop(dataloader, model, loss_fn):
            N = len(dataloader.dataset)
            num_batches = len(dataloader)
            test_loss = 0
            correct_preds = 0

            with torch.no_grad():
                for X, y in dataloader:
                    pred = model(X)
                    test_loss += loss_fn(pred, y).item()
                    correct_preds += (pred.argmax(1) == y).type(torch.float).sum().item()

            test_loss /= num_batches
            print("Test Error: \n Accuracy: {:.1f}%\n Avg loss per minibatch: {:8f} \n".format((100*correct_preds/N), test_loss))
      ]],
      {}
    ),
    {condition = line_begin}
  ),
}
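Taken together, these five snippets cover a standard PyTorch training workflow: imports, a model, a dataset, an optimizer, and train/test loops. Below is a minimal sketch of how their expansions compose, assuming train_loop and test_loop have been expanded from the train and test snippets above; the model and the synthetic data are hypothetical stand-ins for FooNet and FooDataset.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for a FooDataset: 256 instances, 10 features, 3 classes
features = torch.randn(256, 10)
labels = torch.randint(0, 3, (256,))
dataloader = DataLoader(TensorDataset(features, labels), batch_size=64)

# Hypothetical stand-in for a FooNet module
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
loss_fn = nn.CrossEntropyLoss()

# As produced by the optim snippet
optim = torch.optim.SGD(model.parameters(), lr=1e-2)

# train_loop and test_loop as expanded from the train and test snippets
for epoch in range(5):
    train_loop(dataloader, model, loss_fn, optim)
    test_loop(dataloader, model, loss_fn)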