Step 2: Implement the File Server upload and list_files methods¶
Let’s dive in on how to implement these methods.
Implement the upload_file method¶
In this method, we stream the incoming upload, chunk by chunk, into a file stored on the file server's disk.
Once the file is uploaded, we put it into the Drive so that it becomes persistent and accessible to all Components.
class FileServer(LightningWork):
def upload_file(self, file):
    """Stream an upload to disk, publish it to the drive, and record metadata.

    Progress is exposed through ``self.uploaded_files[file.filename]`` so
    that observers can follow the transfer while it is running.

    Args:
        file: Upload object providing a ``filename`` attribute and a
            ``read(size)`` method that returns a falsy value once all data
            has been consumed.
    """
    # 1: Register the upload under its original name with zero progress.
    original_name = file.filename
    stored_name = self.get_random_filename()
    self.uploaded_files[original_name] = {"progress": (0, None), "done": False}

    # 2: Copy the payload to the file server disk, one chunk at a time.
    stored_path = self.get_filepath(stored_name)
    with open(stored_path, "wb") as sink:
        while True:
            # 2.1: Read the next chunk; a falsy chunk means end of data.
            chunk = file.read(self.chunk_size)
            if not chunk:
                break
            # 2.2: Persist the chunk.
            written = sink.write(chunk)
            # 2.3: Add the bytes just written to the running total.
            so_far = self.uploaded_files[original_name]["progress"][0]
            self.uploaded_files[original_name]["progress"] = (so_far + written, None)

    # 3: Publish the payload and flag the upload as complete.
    total_bytes = self.uploaded_files[original_name]["progress"][0]
    self.drive.put(stored_path)
    self.uploaded_files[original_name] = {
        "progress": (total_bytes, total_bytes),
        "done": True,
        "uploaded_file": stored_name,
    }

    # 4: Describe the upload in a ``.meta`` file stored next to the payload.
    description = {
        "original_path": original_name,
        "display_name": os.path.splitext(original_name)[0],
        "size": total_bytes,
        "drive_path": stored_name,
    }
    meta_path = self.get_filepath(stored_name + ".meta")
    with open(meta_path, "w") as meta_handle:
        json.dump(description, meta_handle)

    # 5: Publish the metadata file as well, so that other components can
    # get or list it through the drive.
    self.drive.put(meta_path)
Implement the list_files method¶
First, this method fetches the requested path from the Drive into the file server's filesystem if it is not already available locally. It then lists the files under the provided path and returns the results.
class FileServer(LightningWork):
return meta
def list_files(self, file_path: str):
    """List the assets found under a directory or inside an archive.

    Args:
        file_path: Path to inspect; it is resolved through
            ``self.get_filepath`` before use.

    Returns:
        A dict of the form ``{"asset_names": [...]}``. For a directory, the
        list holds the sorted original names of the uploads whose stored
        file is present under it (``.meta`` files are ignored). For a zip
        or tar archive, it holds the archive's member names.

    Raises:
        ValueError: If the path is neither a directory, a zip file, nor a
            tar file.
    """
    # 1: Get the local file path of the file server.
    file_path = self.get_filepath(file_path)

    # 2: If the path is missing locally, ask the drive for it
    # (presumably this transfers it from the drive; confirm against the
    # Drive API).
    if not os.path.exists(file_path):
        self.drive.get(file_path)

    if os.path.isdir(file_path):
        # Collect the on-disk names once so that each upload is matched
        # with a set lookup instead of a rescan of every upload per file.
        stored_names = {
            name
            for _, _, names in os.walk(file_path)
            for name in names
            if not name.endswith(".meta")
        }
        # Uploads still in progress have no "uploaded_file" key yet, so
        # ``.get`` skips them instead of raising KeyError.
        assets = {
            filename
            for filename, meta in self.uploaded_files.items()
            if meta.get("uploaded_file") in stored_names
        }
        # Sorted so that the listing is deterministic.
        return {"asset_names": sorted(assets)}

    # 3: If the filepath is a tar or zip file, list its contents.
    if zipfile.is_zipfile(file_path):
        with zipfile.ZipFile(file_path, "r") as zf:
            result = zf.namelist()
    elif tarfile.is_tarfile(file_path):
        # ``tarfile.open`` detects compression transparently, whereas the
        # bare ``TarFile`` constructor only reads uncompressed archives.
        with tarfile.open(file_path, "r") as tf:
            result = tf.getnames()
    else:
        raise ValueError("Cannot open archive file!")

    # 4: Return the matching files.
    return {"asset_names": result}
Implement utilities¶
class FileServer(LightningWork):
def get_filepath(self, path: str) -> str:
    """Resolve ``path`` against the file server's base directory.

    Args:
        path: Path to join onto ``self.base_dir``.

    Returns:
        The result of joining ``self.base_dir`` and ``path``.
    """
    root = self.base_dir
    return os.path.join(root, path)
def get_random_filename(self) -> str:
"""Returns a random hash for the file name."""