Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions orbit/actions/export_saved_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
import re

from typing import Callable, Optional
from typing import Any, Callable, List, Optional

import tensorflow as tf, tf_keras

Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(
self._subdirectory = subdirectory or ''

@property
def managed_files(self):
def managed_files(self) -> List[str]:
"""Returns all files managed by this instance, in sorted order.

Returns:
Expand All @@ -116,7 +116,7 @@ def managed_files(self):
files.append(file)
return files

def clean_up(self):
def clean_up(self) -> None:
"""Cleans up old files matching `{base_name}-*`.

The most recent `max_to_keep` files are preserved.
Expand All @@ -141,7 +141,7 @@ class ExportSavedModel:
def __init__(self,
model: tf.Module,
file_manager: ExportFileManager,
signatures,
signatures: Any,
options: Optional[tf.saved_model.SaveOptions] = None):
"""Initializes the instance.

Expand All @@ -157,7 +157,7 @@ def __init__(self,
self.signatures = signatures
self.options = options

def __call__(self, _):
def __call__(self, _) -> None:
"""Exports the SavedModel."""
export_dir = self.file_manager.next_name()
tf.saved_model.save(self.model, export_dir, self.signatures, self.options)
Expand Down
8 changes: 4 additions & 4 deletions orbit/standard_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""

def __init__(self,
train_dataset,
train_dataset: Any,
options: Optional[StandardTrainerOptions] = None):
"""Initializes the `StandardTrainer` instance.

Expand Down Expand Up @@ -146,7 +146,7 @@ def train(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
self._train_loop_fn(self._train_iter, num_steps)
return self.train_loop_end()

def train_loop_begin(self):
def train_loop_begin(self) -> None:
"""Called once at the beginning of the training loop.

This method is always called in eager mode, and is a good place to reset
Expand All @@ -157,7 +157,7 @@ def train_loop_begin(self):
pass

@abc.abstractmethod
def train_step(self, iterator):
def train_step(self, iterator: Any) :
"""Implements one step of training.

What a "step" consists of is up to the implementer. When using distribution
Expand Down Expand Up @@ -259,7 +259,7 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""

def __init__(self,
eval_dataset,
eval_dataset: Any,
options: Optional[StandardEvaluatorOptions] = None):
"""Initializes the `StandardEvaluator` instance.

Expand Down
10 changes: 8 additions & 2 deletions orbit/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import inspect

from typing import Any, Callable, Optional, Union
import tensorflow as tf, tf_keras


Expand Down Expand Up @@ -44,7 +45,12 @@ def create_global_step() -> tf.Variable:
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)


def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
def make_distributed_dataset(
strategy: Optional[tf.distribute.Strategy],
dataset_or_fn: Union[tf.data.Dataset, Callable],
*args: Any,
**kwargs: Any
) -> tf.distribute.DistributedDataset:
"""A utility function to help create a `tf.distribute.DistributedDataset`.

Args:
Expand Down Expand Up @@ -90,7 +96,7 @@ def dataset_fn(input_context):
return strategy.distribute_datasets_from_function(dataset_fn, input_options)


def get_value(x):
def get_value(x) -> Any:
"""Returns input values, converting any TensorFlow values to NumPy values.

Args:
Expand Down
8 changes: 4 additions & 4 deletions orbit/utils/epoch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, epoch_steps: int, global_step: tf.Variable):
self._epoch_start_step = None
self._in_epoch = False

def epoch_begin(self):
def epoch_begin(self) -> bool:
"""Returns whether a new epoch should begin."""
if self._in_epoch:
return False
Expand All @@ -43,7 +43,7 @@ def epoch_begin(self):
self._in_epoch = True
return True

def epoch_end(self):
def epoch_end(self) -> bool:
"""Returns whether the current epoch should end."""
if not self._in_epoch:
raise ValueError("`epoch_end` can only be called inside an epoch.")
Expand All @@ -56,10 +56,10 @@ def epoch_end(self):
return False

@property
def batch_index(self):
def batch_index(self) -> int:
"""Index of the next batch within the current epoch."""
return self._global_step.numpy() - self._epoch_start_step

@property
def current_epoch(self):
def current_epoch(self) -> int:
return self._current_epoch
14 changes: 10 additions & 4 deletions orbit/utils/loop_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
from absl import logging
from orbit.utils import tpu_summaries

from typing import Any, Callable, Optional, Iterator

import tensorflow as tf, tf_keras

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.tpu import embedding_context_utils as ecu
# pylint: enable=g-direct-tensorflow-import


def create_loop_fn(step_fn):
def create_loop_fn(step_fn: Callable[[Any], Any]) -> Callable:
"""Creates a loop function driven by a Python `while` loop.

Args:
Expand All @@ -40,7 +42,11 @@ def create_loop_fn(step_fn):
additional details.
"""

def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
def loop_fn(iterator: Any,
num_steps: int,
state=None,
reduce_fn: Optional[Callable[[Any, Any], Any]] = None
) -> Any:
"""Makes `num_steps` calls to `step_fn(iterator)`.

Additionally, state may be accumulated across iterations of the loop.
Expand Down Expand Up @@ -89,7 +95,7 @@ def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
return loop_fn


def create_tf_while_loop_fn(step_fn):
def create_tf_while_loop_fn(step_fn: Callable[[Any], Any]) -> Callable:
"""Creates a loop function compatible with TF's AutoGraph loop conversion.

Args:
Expand All @@ -103,7 +109,7 @@ def create_tf_while_loop_fn(step_fn):
additional details.
"""

def loop_fn(iterator, num_steps):
def loop_fn(iterator: Any, num_steps: tf.Tensor):
"""Makes `num_steps` calls to `step_fn(iterator)`.

Args:
Expand Down