-rw-r--r--  train_dreambooth.py              |  2
-rw-r--r--  train_lora.py                    |  2
-rw-r--r--  train_ti.py                      |  2
-rw-r--r--  training/functional.py           | 53
-rw-r--r--  training/strategy/dreambooth.py  |  6
5 files changed, 32 insertions, 33 deletions
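Taken together, the patch stops handing gradient_accumulation_steps to the Accelerator constructor in the three entry scripts and instead threads it through training.functional.train into train_loop, which now accumulates gradients by hand: each micro-batch loss is scaled by 1 / gradient_accumulation_steps and the optimizer only steps every gradient_accumulation_steps batches. A minimal, self-contained sketch of that pattern (the toy model, data, and hyperparameters are assumptions; only the accumulation logic mirrors the new loop):

import torch
from torch import nn

# Toy setup so the sketch runs on its own; names and sizes are illustrative,
# only the accumulation pattern itself follows the new train_loop.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(8)]

gradient_accumulation_steps = 4
train_step = 0

for x, y in batches:
    loss = loss_fn(model(x), y)
    # Scale the loss so the summed gradients correspond to a mean over the
    # larger effective batch.
    loss = loss / gradient_accumulation_steps
    loss.backward()

    train_step += 1
    # Step and reset only once per accumulation window, as the new loop does.
    if train_step % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)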
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 431ff3d..280cf77 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -439,7 +439,6 @@ def main():
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{output_dir}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )

@@ -590,6 +589,7 @@ def main():
         lr_scheduler=lr_scheduler,
         prepare_unet=True,
         num_train_epochs=args.num_train_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
diff --git a/train_lora.py b/train_lora.py
index a06591d..d7c2de0 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -399,7 +399,6 @@ def main():
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{output_dir}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )

@@ -561,6 +560,7 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
diff --git a/train_ti.py b/train_ti.py
index 6dc07dd..68783ea 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -518,7 +518,6 @@ def main():
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{output_dir}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )

@@ -611,6 +610,7 @@ def main():
         low_freq_noise=0,
         strategy=textual_inversion_strategy,
         num_train_epochs=args.num_train_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
         milestone_checkpoints=not args.no_milestone_checkpoints,
diff --git a/training/functional.py b/training/functional.py
index 739d055..3f5fa7e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -365,15 +365,17 @@ def train_loop(
     milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
+    gradient_accumulation_steps: int = 1,
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0

     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs

     global_step = 0
+    train_step = 0

     avg_loss = AverageMeter()
     avg_acc = AverageMeter()
@@ -434,44 +436,45 @@ def train_loop(

         with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(model):
-                    loss, acc, bsz = loss_step(step, batch)
+                loss, acc, bsz = loss_step(step, batch)
+                loss /= gradient_accumulation_steps

-                    accelerator.backward(loss)
+                avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)

+                accelerator.backward(loss)
+
+                logs = {
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
+                    "lr": lr_scheduler.get_last_lr()[0],
+                }
+                logs.update(on_log())
+
+                local_progress_bar.set_postfix(**logs)
+
+                train_step += 1
+
+                if train_step % gradient_accumulation_steps == 0:
                     on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)

                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)

-                    avg_loss.update(loss.detach_(), bsz)
-                    avg_acc.update(acc.detach_(), bsz)
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
                     on_after_optimize(lr_scheduler.get_last_lr()[0])

                     local_progress_bar.update(1)
                     global_progress_bar.update(1)

-                    global_step += 1
+                    accelerator.log(logs, step=global_step)

-                    logs = {
-                        "train/loss": avg_loss.avg.item(),
-                        "train/acc": avg_acc.avg.item(),
-                        "train/cur_loss": loss.item(),
-                        "train/cur_acc": acc.item(),
-                        "lr": lr_scheduler.get_last_lr()[0],
-                    }
-                    logs.update(on_log())
-
-                    accelerator.log(logs, step=global_step)
-
-                    local_progress_bar.set_postfix(**logs)
+                    global_step += 1

                 if global_step >= num_training_steps:
                     break

         accelerator.wait_for_everyone()

@@ -571,6 +574,7 @@ def train(
     strategy: TrainingStrategy,
     no_val: bool = False,
     num_train_epochs: int = 100,
+    gradient_accumulation_steps: int = 1,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
@@ -631,6 +635,7 @@ def train(
         milestone_checkpoints=milestone_checkpoints,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
+        gradient_accumulation_steps=gradient_accumulation_steps,
         callbacks=callbacks,
     )

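The loss /= gradient_accumulation_steps scaling keeps the accumulated gradient equal to the gradient of the mean loss over the whole accumulation window (for equally sized micro-batches), which is also why num_training_steps_per_epoch can simply divide the dataloader length by gradient_accumulation_steps. A small numerical check of that equivalence, purely illustrative and not part of the patch:

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
micro_batches = [(torch.randn(2, 3), torch.randn(2, 1)) for _ in range(4)]
n = len(micro_batches)

# Accumulate gradients from n micro-batches, each loss scaled by 1/n.
model.zero_grad()
for x, y in micro_batches:
    (loss_fn(model(x), y) / n).backward()
accumulated = model.weight.grad.clone()

# Single backward pass over the concatenated batch.
model.zero_grad()
x_full = torch.cat([x for x, _ in micro_batches])
y_full = torch.cat([y for _, y in micro_batches])
loss_fn(model(x_full), y_full).backward()
full_batch = model.weight.grad.clone()

print(torch.allclose(accumulated, full_batch, atol=1e-6))  # expected: True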
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index d697554..fcf5c0d 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -41,12 +41,6 @@ def dreambooth_strategy_callbacks(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
-    if accelerator.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
-
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

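The dropped block in dreambooth_strategy_callbacks guarded against combining accelerator.gradient_accumulation_steps > 1 with multi-process training. Since accumulation no longer goes through the Accelerator at all, that attribute presumably stays at its default of 1 and the check could never fire; whether the underlying distributed limitation still needs a guard elsewhere is not settled by this diff. If such a check were reintroduced at the strategy level, it would have to inspect the manually passed value, along these lines (hypothetical, not part of the patch):

def check_gradient_accumulation(gradient_accumulation_steps: int, num_processes: int) -> None:
    # Hypothetical replacement guard: the Accelerator no longer carries the
    # accumulation setting, so a check would need the value passed to train().
    if gradient_accumulation_steps > 1 and num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder "
            "in distributed training. Please set gradient_accumulation_steps to 1."
        )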