nn.Module or LightningModule when building a model from multiple classes?

I’m looking for general best practices for building something like a VAE that uses pre-defined encoder and decoder classes. For example:

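```python
import pytorch_lightning as pl
import torch.nn as nn

class MyEncoder(nn.Module):
    def __init__(self):  # constructor arguments omitted
        super().__init__()
        ...

class MyDecoder(nn.Module):
    def __init__(self):  # constructor arguments omitted
        super().__init__()
        ...

class VAE(pl.LightningModule):
    def __init__(self):  # constructor arguments omitted
        super().__init__()  # required before assigning sub-modules
        self.encoder = MyEncoder()
        self.decoder = MyDecoder()
```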

Is it generally best practice for the classes MyEncoder and MyDecoder to inherit from nn.Module, or should they also inherit from pl.LightningModule?

I think the best practice is to have VAE inherit from nn.Module as well, and to wrap it in a separate VAELitModule (a pl.LightningModule) that handles training.

Related: GitHub issue Lightning-AI/lightning #11834, "rework docs to encourage the best practice of having a separate nn.Module"
