import io
import os
import zipfile
from typing import Any, Dict, List, Optional, Tuple, cast
import streamlit as st
from deta import Deta
from deta import _Base as DetaBase
from deta import _Drive as DetaDrive
from loguru import logger
from typing_extensions import Literal
from . import field_def
from .const import DataDefsCollectionDict, DataSample
TakeVarsFrom = Literal["st_secrets", "env"]
[docs]def connect_to_db(
deta_key_secret: str,
base_name_env_var: str,
take_vars_from: TakeVarsFrom = "st_secrets",
drive_name_env_var: Optional[str] = None,
) -> Tuple[Deta, DetaBase, Optional[DetaDrive]]:
var_taker = st.secrets if take_vars_from == "st_secrets" else os.environ
deta = Deta(var_taker[deta_key_secret])
base = deta.Base(var_taker[base_name_env_var])
drive: Optional[DetaDrive] = None
if drive_name_env_var:
drive = deta.Drive(var_taker[drive_name_env_var])
return deta, base, drive
[docs]def download_zipped_dir(drive: DetaDrive, zip_file: str = "data.zip", local_dir: str = "./data") -> None:
# NOTE: Will only download and extract if the local directory does not exist.
directory = os.path.realpath(local_dir)
if not os.path.exists(directory):
logger.info(f"Local directory {local_dir} does not exist, creating")
os.makedirs(directory)
# Get zip file from Deta Drive.
logger.info(f"Downloading {zip_file} from Deta Drive")
# For debug:
# logger.info(drive.list())
file = drive.get(zip_file)
if file is None:
raise RuntimeError(f"File {zip_file} not found on Deta Drive")
logger.info(f"Unzipping {zip_file} to {local_dir}")
bytes_ = file.read()
with zipfile.ZipFile(io.BytesIO(bytes_), "r") as zip_ref:
zip_ref.extractall(directory)
logger.info("Downloading and extracting zip file finished")
[docs]def get_all_sample_keys(db: DetaBase) -> List[str]:
# TODO: This is inefficient. Needs to be improved.
all_data = db.fetch()
# if all_data.count == 0:
# raise RuntimeError("No data found")
if all_data.last is not None:
raise RuntimeError("Too many data rows. Supported max rows is 1000.")
return [example["key"] for example in all_data.items]
def _sort_fields(sort_key: List[str], fields: Dict[str, Dict]) -> Dict[str, Dict]:
# Sort the fields in field_defs order (the fields in the DB are in random order).
sorted_fields: Dict[str, Any] = dict()
for key in sort_key:
sorted_fields[key] = fields[key]
return sorted_fields
def _sort_fields_in_array(sort_key: List[str], array_of_fields: List[Dict[str, Dict]]) -> List[Dict[str, Dict]]:
sorted_array_of_fields: List[Dict[str, Dict]] = []
for fields in array_of_fields:
sorted_array_of_fields.append(_sort_fields(sort_key=sort_key, fields=fields))
return sorted_array_of_fields
[docs]def get_sample(key: str, db: DetaBase, field_defs: "field_def.FieldDefsCollection") -> DataSample:
raw_data = cast(DataDefsCollectionDict, db.get(key))
static = _sort_fields(sort_key=list(field_defs.static.keys()), fields=raw_data["static"])
temporal = _sort_fields_in_array(sort_key=list(field_defs.temporal.keys()), array_of_fields=raw_data["temporal"])
event = _sort_fields_in_array(sort_key=list(field_defs.event.keys()), array_of_fields=raw_data["event"])
static = field_def.process_db_to_input(field_defs=field_defs.static, data=static)
temporal = [field_def.process_db_to_input(field_defs=field_defs.temporal, data=x) for x in temporal]
event = [field_def.process_db_to_input(field_defs=field_defs.event, data=x) for x in event]
return DataSample(static=static, temporal=temporal, event=event)
[docs]def add_empty_sample(db: DetaBase, key: str, field_defs: "field_def.FieldDefsCollection", current_timestep: Any):
# Get non-computed defaults.
static = field_def.get_default(field_defs=field_defs.static, modality="static") if field_defs.static else dict()
temporal_0 = (
field_def.get_default(field_defs=field_defs.temporal, modality="temporal", data_sample="first_step")
if field_defs.temporal
else dict()
)
event_0 = field_def.get_default(field_defs=field_defs.event, modality="event") if field_defs.event else dict()
data_sample = DataSample(
static=static,
temporal=[temporal_0] if temporal_0 else [],
event=[event_0] if event_0 else [],
)
# Get computed defaults.
if field_defs.static:
data_sample.static = field_def.get_default_computed(
field_defs=field_defs.static,
modality="static",
data_sample_before_computation=data_sample,
current_timestep=current_timestep,
)
if field_defs.temporal:
temporal_0 = field_def.get_default_computed(
field_defs=field_defs.temporal,
modality="temporal",
data_sample_before_computation=data_sample,
current_timestep=current_timestep,
)
data_sample.temporal = [temporal_0]
if field_defs.event:
# TODO: Event is not yet properly handled.
event_0 = field_def.get_default_computed(
field_defs=field_defs.event,
modality="event",
data_sample_before_computation=data_sample,
current_timestep=current_timestep,
)
data_sample.event = [event_0]
static = field_def.process_input_to_db(field_defs=field_defs.static, data=data_sample.static)
temporal = [field_def.process_input_to_db(field_defs=field_defs.temporal, data=temporal_0)]
event = [field_def.process_input_to_db(field_defs=field_defs.event, data=event_0)]
data_sample_for_db = dict(DataSample(static=static, temporal=temporal, event=event))
logger.info(f"Adding new sample to db.\nkey: {key}\ndata:\n{data_sample_for_db}")
db.put(data_sample_for_db, key=key)
[docs]def delete_sample(db: DetaBase, key: str):
logger.info(f"Deleting sample from db.\nkey: {key}")
db.delete(key=key)
[docs]def update_sample(db: DetaBase, key: str, data_sample: DataSample, field_defs: "field_def.FieldDefsCollection"):
static = field_def.process_input_to_db(field_defs=field_defs.static, data=data_sample.static)
temporal = [field_def.process_input_to_db(field_defs=field_defs.temporal, data=x) for x in data_sample.temporal]
event = [field_def.process_input_to_db(field_defs=field_defs.event, data=x) for x in data_sample.event]
data_sample_processed = dict(DataSample(static=static, temporal=temporal, event=event))
logger.info(f"Updating sample sample in db.\nkey: {key}\ndata:\n{data_sample_processed}")
db.put(dict(data_sample_processed), key=key)