--[[
Context-Encoder training script: trains an encoder-decoder generator (netG)
to inpaint randomly masked regions of images against an adversarial
discriminator (netD), optionally adding an L2 reconstruction loss
(opt.wtl2), with extra L2 weight on the unmasked overlap region when
opt.useOverlapPred == 1. Every option below can be overridden through a
same-named environment variable.
]]

require 'torch'
require 'nn'
require 'optim'
require 'image'

util = paths.dofile('util.lua')

opt = {
   batchSize = 64,        -- number of samples to produce
   loadSize = 350,        -- resize the loaded image to loadsize maintaining aspect ratio. 0 means don't resize. -1 means scale randomly between [0.5,2] -- see donkey_folder.lua
   fineSize = 128,        -- size of random crops. Only 64 and 128 supported.
   nBottleneck = 100,     -- # of dim for bottleneck of encoder
   nef = 64,              -- # of encoder filters in first conv layer
   ngf = 64,              -- # of gen filters in first conv layer
   ndf = 64,              -- # of discrim filters in first conv layer
   nc = 3,                -- # of channels in input
   wtl2 = 0,              -- 0 means don't use else use with this weight
   useOverlapPred = 0,    -- overlapping edges (1 means yes, 0 means no). 1 means put 10x more L2 weight on unmasked region.
   nThreads = 4,          -- # of data loading threads to use
   niter = 25,            -- # of iter at starting learning rate
   lr = 0.0002,           -- initial learning rate for adam
   beta1 = 0.5,           -- momentum term of adam
   ntrain = math.huge,    -- # of examples per epoch. math.huge for full dataset
   display = 1,           -- display samples while training. 0 = false
   display_id = 10,       -- display window id.
   display_iter = 50,     -- # number of iterations after which display is updated
   gpu = 1,               -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
   name = 'train1',       -- name of the experiment you are running
   manualSeed = 0,        -- 0 means random seed
   -- Extra Options:
   conditionAdv = 0,      -- 0 means false else true
   noiseGen = 0,          -- 0 means false else true
   noisetype = 'normal',  -- uniform / normal
   nz = 100,              -- # of dim for Z
}

-- Override any option from the environment: numeric strings are parsed,
-- everything else is taken verbatim.
for k, v in pairs(opt) do
   opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k]
end
print(opt)

-- Convert 0/1 integer flags into real booleans.
if opt.display == 0 then opt.display = false end
if opt.conditionAdv == 0 then opt.conditionAdv = false end
if opt.noiseGen == 0 then opt.noiseGen = false end

-- set seed (0 requests a randomly drawn seed)
if opt.manualSeed == 0 then
   opt.manualSeed = torch.random(1, 10000)
end
print("Seed: " .. opt.manualSeed)
torch.manualSeed(opt.manualSeed)
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')

-- create data loader
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt)
print("Dataset Size: ", data:size())

---------------------------------------------------------------------------
-- Initialize network variables
---------------------------------------------------------------------------
-- DCGAN-style initialization: conv weights ~ N(0, 0.02), zero conv biases,
-- batch-norm gamma ~ N(1, 0.02), zero batch-norm beta.
local function weights_init(m)
   local name = torch.type(m)
   if name:find('Convolution') then
      m.weight:normal(0.0, 0.02)
      m.bias:fill(0)
   elseif name:find('BatchNormalization') then
      if m.weight then m.weight:normal(1.0, 0.02) end
      if m.bias then m.bias:fill(0) end
   end
end

local nc = opt.nc
local nz = opt.nz
local nBottleneck = opt.nBottleneck
local ndf = opt.ndf
local ngf = opt.ngf
local nef = opt.nef
local real_label = 1
local fake_label = 0

local SpatialBatchNormalization = nn.SpatialBatchNormalization
local SpatialConvolution = nn.SpatialConvolution
local SpatialFullConvolution = nn.SpatialFullConvolution

---------------------------------------------------------------------------
-- Generator net
---------------------------------------------------------------------------
-- Encode Input Context to noise (architecture similar to Discriminator)
local netE = nn.Sequential()
-- input is (nc) x 128 x 128
netE:add(SpatialConvolution(nc, nef, 4, 4, 2, 2, 1, 1))
netE:add(nn.LeakyReLU(0.2, true))
if opt.fineSize == 128 then
   -- state size: (nef) x 64 x 64
   netE:add(SpatialConvolution(nef, nef, 4, 4, 2, 2, 1, 1))
   netE:add(SpatialBatchNormalization(nef)):add(nn.LeakyReLU(0.2, true))
end
-- state size: (nef) x 32 x 32
netE:add(SpatialConvolution(nef, nef * 2, 4, 4, 2, 2, 1, 1))
netE:add(SpatialBatchNormalization(nef * 2)):add(nn.LeakyReLU(0.2, true))
-- state size: (nef*2) x 16 x 16
netE:add(SpatialConvolution(nef * 2, nef * 4, 4, 4, 2, 2, 1, 1))
netE:add(SpatialBatchNormalization(nef * 4)):add(nn.LeakyReLU(0.2, true))
-- state size: (nef*4) x 8 x 8
netE:add(SpatialConvolution(nef * 4, nef * 8, 4, 4, 2, 2, 1, 1))
netE:add(SpatialBatchNormalization(nef * 8)):add(nn.LeakyReLU(0.2, true))
-- state size: (nef*8) x 4 x 4
netE:add(SpatialConvolution(nef * 8, nBottleneck, 4, 4))
-- state size: (nBottleneck) x 1 x 1

local netG = nn.Sequential()
local nz_size = nBottleneck
if opt.noiseGen then
   -- Concatenate the encoded context with a learned projection of noise Z.
   local netG_noise = nn.Sequential()
   -- input is Z: (nz) x 1 x 1, going into a convolution
   netG_noise:add(SpatialConvolution(nz, nz, 1, 1, 1, 1, 0, 0))
   -- state size: (nz) x 1 x 1
   local netG_pl = nn.ParallelTable()
   netG_pl:add(netE)
   netG_pl:add(netG_noise)
   netG:add(netG_pl)
   netG:add(nn.JoinTable(2))
   netG:add(SpatialBatchNormalization(nBottleneck + nz)):add(nn.LeakyReLU(0.2, true))
   -- state size: (nBottleneck+nz) x 1 x 1
   nz_size = nBottleneck + nz
else
   netG:add(netE)
   netG:add(SpatialBatchNormalization(nBottleneck)):add(nn.LeakyReLU(0.2, true))
   nz_size = nBottleneck
end

-- Decode noise to generate image
-- input is Z: (nz_size) x 1 x 1, going into a convolution
netG:add(SpatialFullConvolution(nz_size, ngf * 8, 4, 4))
netG:add(SpatialBatchNormalization(ngf * 8)):add(nn.ReLU(true))
-- state size: (ngf*8) x 4 x 4
netG:add(SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf * 4)):add(nn.ReLU(true))
-- state size: (ngf*4) x 8 x 8
netG:add(SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf * 2)):add(nn.ReLU(true))
-- state size: (ngf*2) x 16 x 16
netG:add(SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf)):add(nn.ReLU(true))
-- state size: (ngf) x 32 x 32
if opt.fineSize == 128 then
   netG:add(SpatialFullConvolution(ngf, ngf, 4, 4, 2, 2, 1, 1))
   netG:add(SpatialBatchNormalization(ngf)):add(nn.ReLU(true))
   -- state size: (ngf) x 64 x 64
end
netG:add(SpatialFullConvolution(ngf, nc, 4, 4, 2, 2, 1, 1))
netG:add(nn.Tanh())
-- state size: (nc) x 128 x 128

netG:apply(weights_init)

---------------------------------------------------------------------------
-- Adversarial discriminator net
---------------------------------------------------------------------------
local netD = nn.Sequential()
if opt.conditionAdv then
   print('conditional adv not implemented')
   -- BUGFIX: the original called exit(), an undefined global, which would
   -- crash with "attempt to call a nil value"; error() halts cleanly.
   error('conditional adv not implemented')
   -- NOTE(review): everything below in this branch is unreachable
   -- scaffolding for the conditional discriminator, kept for reference.
   local netD_ctx = nn.Sequential()
   -- input Context: (nc) x 128 x 128, going into a convolution
   netD_ctx:add(SpatialConvolution(nc, ndf, 5, 5, 2, 2, 2, 2))
   -- state size: (ndf) x 64 x 64
   local netD_pred = nn.Sequential()
   -- input pred: (nc) x 64 x 64, going into a convolution
   netD_pred:add(SpatialConvolution(nc, ndf, 5, 5, 2, 2, 2 + 32, 2 + 32)) -- 32: to keep scaling of features same as context
   -- state size: (ndf) x 64 x 64
   local netD_pl = nn.ParallelTable()
   netD_pl:add(netD_ctx)
   netD_pl:add(netD_pred)
   netD:add(netD_pl)
   netD:add(nn.JoinTable(2))
   netD:add(nn.LeakyReLU(0.2, true))
   -- state size: (ndf * 2) x 64 x 64
   netD:add(SpatialConvolution(ndf * 2, ndf, 4, 4, 2, 2, 1, 1))
   netD:add(SpatialBatchNormalization(ndf)):add(nn.LeakyReLU(0.2, true))
   -- state size: (ndf) x 32 x 32
else
   -- input is (nc) x 128 x 128, going into a convolution
   netD:add(SpatialConvolution(nc, ndf, 4, 4, 2, 2, 1, 1))
   netD:add(nn.LeakyReLU(0.2, true))
   -- state size: (ndf) x 64 x 64
end
if opt.fineSize == 128 then
   netD:add(SpatialConvolution(ndf, ndf, 4, 4, 2, 2, 1, 1))
   netD:add(SpatialBatchNormalization(ndf)):add(nn.LeakyReLU(0.2, true))
   -- state size: (ndf) x 32 x 32
end
netD:add(SpatialConvolution(ndf, ndf * 2, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*2) x 16 x 16
netD:add(SpatialConvolution(ndf * 2, ndf * 4, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 4)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*4) x 8 x 8
netD:add(SpatialConvolution(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*8) x 4 x 4
netD:add(SpatialConvolution(ndf * 8, 1, 4, 4))
netD:add(nn.Sigmoid())
-- state size: 1 x 1 x 1
netD:add(nn.View(1):setNumInputDims(3))
-- state size: 1

netD:apply(weights_init)

---------------------------------------------------------------------------
-- Loss Metrics
---------------------------------------------------------------------------
local criterion = nn.BCECriterion()
local criterionMSE
if opt.wtl2 ~= 0 then
   criterionMSE = nn.MSECriterion()
end

---------------------------------------------------------------------------
-- Setup Solver
---------------------------------------------------------------------------
-- When 0 < wtl2 < 1 the generator loss is a convex combination of the
-- adversarial and L2 terms, so its learning rate is boosted 10x.
print('LR of Gen is ', (opt.wtl2 > 0 and opt.wtl2 < 1) and 10 or 1, 'times Adv')
optimStateG = {
   learningRate = (opt.wtl2 > 0 and opt.wtl2 < 1) and opt.lr * 10 or opt.lr,
   beta1 = opt.beta1,
}
optimStateD = {
   learningRate = opt.lr,
   beta1 = opt.beta1,
}

---------------------------------------------------------------------------
-- Initialize data variables
---------------------------------------------------------------------------
local mask_global = torch.ByteTensor(opt.batchSize, opt.fineSize, opt.fineSize)
local input_ctx_vis = torch.Tensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
local input_ctx = torch.Tensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
local input_center = torch.Tensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
local input_real_center
if opt.wtl2 ~= 0 then
   input_real_center = torch.Tensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
end
local noise = torch.Tensor(opt.batchSize, nz, 1, 1)
local label = torch.Tensor(opt.batchSize)
local errD, errG, errG_l2
local epoch_tm = torch.Timer()
local tm = torch.Timer()
local data_tm = torch.Timer()

if pcall(require, 'cudnn') and pcall(require, 'cunn') and opt.gpu > 0 then
   print('Using CUDNN !')
end
if opt.gpu > 0 then
   require 'cunn'
   cutorch.setDevice(opt.gpu)
   input_ctx_vis = input_ctx_vis:cuda()
   input_ctx = input_ctx:cuda()
   input_center = input_center:cuda()
   noise = noise:cuda()
   label = label:cuda()
   netG = util.cudnn(netG)
   netD = util.cudnn(netD)
   netD:cuda()
   netG:cuda()
   criterion:cuda()
   if opt.wtl2 ~= 0 then
      criterionMSE:cuda()
      input_real_center = input_real_center:cuda()
   end
end
print('NetG:', netG)
print('NetD:', netD)

-- Generating random pattern
local res = 0.06 -- the lower it is, the more continuous the output will be. 0.01 is too small and 0.1 is too large
local density = 0.25
local MAX_SIZE = 10000
local low_pattern = torch.Tensor(res * MAX_SIZE, res * MAX_SIZE):uniform(0, 1):mul(255)
local pattern = image.scale(low_pattern, MAX_SIZE, MAX_SIZE, 'bicubic')
low_pattern = nil
pattern:div(255)
-- 25% 1s and 75% 0s (the redundant second :byte() call was dropped)
pattern = torch.lt(pattern, density):byte()
print('...Random pattern generated')

-- Sample a fineSize x fineSize binary mask as a view into the precomputed
-- pattern (no allocation); resample until the masked area is within
-- (20%, 30%) so roughly 75% of pixels are 0s and 25% are 1s.
-- (Extracted helper: this loop was duplicated in fDx and the display code.)
local function sample_mask()
   while true do
      local x = torch.uniform(1, MAX_SIZE - opt.fineSize)
      local y = torch.uniform(1, MAX_SIZE - opt.fineSize)
      local mask = pattern[{{y, y + opt.fineSize - 1}, {x, x + opt.fineSize - 1}}] -- view, no allocation
      local area = mask:sum() * 100. / (opt.fineSize * opt.fineSize)
      if area > 20 and area < 30 then
         return mask
      end
   end
end

local parametersD, gradParametersD = netD:getParameters()
local parametersG, gradParametersG = netG:getParameters()

if opt.display then disp = require 'display' end

-- Fixed noise used for visualization so displayed samples are comparable.
noise_vis = noise:clone()
if opt.noisetype == 'uniform' then
   noise_vis:uniform(-1, 1)
elseif opt.noisetype == 'normal' then
   noise_vis:normal(0, 1)
end

---------------------------------------------------------------------------
-- Define generator and adversary closures
---------------------------------------------------------------------------
-- create closure to evaluate f(X) and df/dX of discriminator
local fDx = function(x)
   -- conv biases are pinned at zero before every update
   netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
   netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)

   gradParametersD:zero()

   -- train with real
   data_tm:reset(); data_tm:resume()
   local real_ctx = data:getBatch()
   local real_center = real_ctx -- view (BUGFIX: was an accidental global)
   input_center:copy(real_center)
   if opt.wtl2 ~= 0 then input_real_center:copy(real_center) end

   -- get random mask, then paint the masked pixels with the normalized
   -- per-channel mean color
   local mask = sample_mask()
   torch.repeatTensor(mask_global, mask, opt.batchSize, 1, 1)
   real_ctx[{{}, {1}, {}, {}}][mask_global] = 2 * 117.0 / 255.0 - 1.0
   real_ctx[{{}, {2}, {}, {}}][mask_global] = 2 * 104.0 / 255.0 - 1.0
   real_ctx[{{}, {3}, {}, {}}][mask_global] = 2 * 123.0 / 255.0 - 1.0
   input_ctx:copy(real_ctx)
   data_tm:stop()

   label:fill(real_label)
   local output
   if opt.conditionAdv then
      output = netD:forward({input_ctx, input_center})
   else
      output = netD:forward(input_center)
   end
   local errD_real = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   if opt.conditionAdv then
      netD:backward({input_ctx, input_center}, df_do)
   else
      netD:backward(input_center, df_do)
   end

   -- train with fake
   if opt.noisetype == 'uniform' then -- regenerate random noise
      noise:uniform(-1, 1)
   elseif opt.noisetype == 'normal' then
      noise:normal(0, 1)
   end
   local fake
   if opt.noiseGen then
      fake = netG:forward({input_ctx, noise})
   else
      fake = netG:forward(input_ctx)
   end
   input_center:copy(fake)
   label:fill(fake_label)
   if opt.conditionAdv then
      output = netD:forward({input_ctx, input_center})
   else
      output = netD:forward(input_center)
   end
   local errD_fake = criterion:forward(output, label)
   df_do = criterion:backward(output, label)
   if opt.conditionAdv then
      netD:backward({input_ctx, input_center}, df_do)
   else
      netD:backward(input_center, df_do)
   end

   errD = errD_real + errD_fake
   return errD, gradParametersD
end

-- create closure to evaluate f(X) and df/dX of generator
local fGx = function(x)
   netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
   netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)

   gradParametersG:zero()

   --[[ the three lines below were already executed in fDx, so save computation
   noise:uniform(-1, 1) -- regenerate random noise
   local fake = netG:forward({input_ctx,noise})
   input_center:copy(fake) ]]--
   label:fill(real_label) -- fake labels are real for generator cost

   local output = netD.output -- netD:forward({input_ctx,input_center}) was already executed in fDx, so save computation
   errG = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   local df_dg
   if opt.conditionAdv then
      df_dg = netD:updateGradInput({input_ctx, input_center}, df_do)
      df_dg = df_dg[2] -- df_dg[2] because conditional GAN
   else
      df_dg = netD:updateGradInput(input_center, df_do)
   end

   local errG_total = errG
   if opt.wtl2 ~= 0 then
      errG_l2 = criterionMSE:forward(input_center, input_real_center)
      local df_dg_l2 = criterionMSE:backward(input_center, input_real_center)

      if opt.useOverlapPred == 0 then
         if (opt.wtl2 > 0 and opt.wtl2 < 1) then
            df_dg:mul(1 - opt.wtl2):add(opt.wtl2, df_dg_l2)
            errG_total = (1 - opt.wtl2) * errG + opt.wtl2 * errG_l2
         else
            df_dg:add(opt.wtl2, df_dg_l2)
            errG_total = errG + opt.wtl2 * errG_l2
         end
      else
         -- put overlapL2Weight x more L2 weight on the unmasked region
         -- (wtl2Matrix holds the per-pixel L2 weights)
         local overlapL2Weight = 10
         local wtl2Matrix = df_dg_l2:clone():fill(overlapL2Weight * opt.wtl2)
         for i = 1, 3 do
            wtl2Matrix[{{}, {i}, {}, {}}][mask_global] = opt.wtl2
         end
         if (opt.wtl2 > 0 and opt.wtl2 < 1) then
            df_dg:mul(1 - opt.wtl2):addcmul(1, wtl2Matrix, df_dg_l2)
            errG_total = (1 - opt.wtl2) * errG + opt.wtl2 * errG_l2
         else
            df_dg:addcmul(1, wtl2Matrix, df_dg_l2)
            errG_total = errG + opt.wtl2 * errG_l2
         end
      end
   end

   if opt.noiseGen then
      netG:backward({input_ctx, noise}, df_dg)
   else
      netG:backward(input_ctx, df_dg)
   end

   return errG_total, gradParametersG
end

---------------------------------------------------------------------------
-- Train Context Encoder
---------------------------------------------------------------------------
paths.mkdir('checkpoints') -- hoisted out of the epoch loop (idempotent)
for epoch = 1, opt.niter do
   epoch_tm:reset()
   local counter = 0
   for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do
      tm:reset()
      -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
      optim.adam(fDx, parametersD, optimStateD)
      -- (2) Update G network: maximize log(D(G(z)))
      optim.adam(fGx, parametersG, optimStateG)

      -- display
      counter = counter + 1
      if counter % opt.display_iter == 0 and opt.display then
         local real_ctx = data:getBatch()
         -- disp.image(real_ctx, {win=opt.display_id * 6, title=opt.name})
         local mask = torch.repeatTensor(sample_mask(), opt.batchSize, 1, 1)
         real_ctx[{{}, {1}, {}, {}}][mask] = 2 * 117.0 / 255.0 - 1.0
         real_ctx[{{}, {2}, {}, {}}][mask] = 2 * 104.0 / 255.0 - 1.0
         real_ctx[{{}, {3}, {}, {}}][mask] = 2 * 123.0 / 255.0 - 1.0
         input_ctx_vis:copy(real_ctx)
         local fake
         if opt.noiseGen then
            fake = netG:forward({input_ctx_vis, noise_vis})
         else
            fake = netG:forward(input_ctx_vis)
         end
         disp.image(fake, {win = opt.display_id, title = opt.name})
         -- show the context with the masked region blanked to white
         real_ctx[{{}, {1}, {}, {}}][mask] = 1.0
         real_ctx[{{}, {2}, {}, {}}][mask] = 1.0
         real_ctx[{{}, {3}, {}, {}}][mask] = 1.0
         disp.image(real_ctx, {win = opt.display_id * 3, title = opt.name})
      end

      -- logging
      if ((i - 1) / opt.batchSize) % 1 == 0 then
         print(('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
               .. ' Err_G_L2: %.4f Err_G: %.4f Err_D: %.4f'):format(
               epoch, ((i - 1) / opt.batchSize),
               math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
               tm:time().real, data_tm:time().real,
               errG_l2 or -1, errG and errG or -1, errD and errD or -1))
      end
   end

   -- nil the flattened parameter views to avoid spiking memory during save
   parametersD, gradParametersD = nil, nil
   parametersG, gradParametersG = nil, nil
   if epoch % 20 == 0 then
      util.save('checkpoints/' .. opt.name .. '_' .. epoch .. '_net_G.t7', netG, opt.gpu)
      util.save('checkpoints/' .. opt.name .. '_' .. epoch .. '_net_D.t7', netD, opt.gpu)
   end
   parametersD, gradParametersD = netD:getParameters() -- reflatten the params and get them
   parametersG, gradParametersG = netG:getParameters()
   print(('End of epoch %d / %d \t Time Taken: %.3f'):format(
         epoch, opt.niter, epoch_tm:time().real))
end
-- NOTE(review): removed non-Lua page-scrape artifact (cookie-consent banner text) that broke the file.