136 lines
4.4 KiB
Python
136 lines
4.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""
|
|
Streaming module API that should be implemented by all Streaming components.
|
|
"""
|
|
|
|
from contextlib import contextmanager
|
|
import typing as tp
|
|
from torch import nn
|
|
import torch
|
|
|
|
|
|
# Streaming state of one component: key (without dots) -> tensor.
State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming children modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` after.
    """
    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Callable[[str, "StreamingModule"], tp.Any]) -> None:
        """Call `fn(name, module)` on every `StreamingModule` yielded by `named_modules()`.

        This includes `self` (whose name is the empty string). Modules that are
        not `StreamingModule` instances are skipped.
        """
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool) -> None:
        """Set the `_is_streaming` flag on this module and all streaming sub-modules.
        """
        def _set(name: str, module: StreamingModule) -> None:
            module._is_streaming = streaming

        self._apply_named_streaming(_set)

    @contextmanager
    def streaming(self) -> tp.Iterator[None]:
        """Context manager to enter streaming mode. Reset streaming state on exit.

        The flag is cleared and the state reset in a `finally` clause, so this
        also happens when the body of the `with` statement raises.
        """
        self._set_streaming(True)
        try:
            yield
        finally:
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self) -> None:
        """Reset the streaming state, including that of sub-modules.
        """
        def _reset(name: str, module: StreamingModule) -> None:
            module._streaming_state.clear()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules.

        Keys of sub-modules are prefixed with the sub-module name followed by
        a dot; keys of this module are returned as is. The returned dict is a
        new dict, but the tensors are not copied.
        """
        state: State = {}

        def _add(name: str, module: StreamingModule) -> None:
            prefix = name + "." if name else ""
            for key, value in module._streaming_state.items():
                state[prefix + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State) -> None:
        """Set the streaming state, including that of sub-modules.

        The previous state of this module and of every streaming sub-module is
        dropped first. `state` itself is not modified.

        Args:
            state: mapping using the same key layout as the one returned by
                `get_streaming_state`.

        Raises:
            AssertionError: if some keys do not belong to any streaming module.
                Keys that could be matched have already been assigned when
                this is raised.
        """
        # Index the streaming modules by qualified name so that each key is
        # dispatched with a single lookup rather than scanning all the keys
        # once per module.
        modules: tp.Dict[str, StreamingModule] = {}

        def _collect(name: str, module: StreamingModule) -> None:
            module._streaming_state.clear()
            modules[name] = module

        self._apply_named_streaming(_collect)

        unexpected: tp.List[str] = []
        for key, value in state.items():
            # Local keys never contain dots, so the owner is everything before
            # the last dot ('' stands for this module itself).
            owner_name, sep, local_key = key.rpartition('.')
            # A dotted key with an empty prefix (e.g. ".x") has no owner.
            owner = modules.get(owner_name) if (owner_name or not sep) else None
            if owner is None:
                unexpected.append(key)
            else:
                owner._streaming_state[local_key] = value

        if unexpected:
            # Raised explicitly rather than with `assert` so the validation is
            # not stripped when running with `python -O`.
            raise AssertionError(f"Unexpected keys in streaming state: {unexpected}")

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.
        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        output a flushed buffer.

        The default implementation has nothing buffered: it returns None when
        `x` is None, and otherwise the result of calling this module on `x`.
        """
        if x is None:
            return None
        else:
            return self(x)
|
|
|
|
|
|
class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming compatible alternative of `nn.Sequential`.

    Flushing walks the contained modules in order and chains the flushed
    output of each one into the next.
    """
    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush every contained module in order and return the final output.

        Streaming modules always get their `flush` called with the current
        pending output (possibly None). Other modules are applied to the
        pending output only once there is one, and are skipped otherwise.
        Returns None when nothing was flushed.
        """
        pending = x
        for layer in self:
            if not isinstance(layer, StreamingModule):
                # Plain module: nothing to flush, only forward existing data.
                if pending is None:
                    continue
                pending = layer(pending)
            else:
                pending = layer.flush(pending)
        return pending
|