aboutsummaryrefslogtreecommitdiff
path: root/libs/storage.py
diff options
context:
space:
mode:
Diffstat (limited to 'libs/storage.py')
-rw-r--r--libs/storage.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/libs/storage.py b/libs/storage.py
new file mode 100644
index 0000000..6220cfc
--- /dev/null
+++ b/libs/storage.py
@@ -0,0 +1,75 @@
1from pathlib import Path
2from typing import Any, Mapping
3
4from aiofiles import open as open
5from aiogram.fsm.state import State
6from aiogram.fsm.storage.base import (
7 BaseStorage,
8 DefaultKeyBuilder,
9 KeyBuilder,
10 StateType,
11 StorageKey,
12)
13from pydantic import TypeAdapter
14from pydantic.main import BaseModel
15
16
17class Record(BaseModel):
18 data: dict[str, Any] = {}
19 state: str | None = None
20
21
22class JsonStorage(BaseStorage):
23 file_path: Path
24 records: dict[str, Record]
25 records_adapter: TypeAdapter
26 key_builder: KeyBuilder
27
28 def __init__(self, file_path: Path, key_builder: KeyBuilder | None = None) -> None:
29 self.file_path = file_path
30 self.records = {}
31 self.records_adapter = TypeAdapter(dict[str, Record])
32 self.key_builder = DefaultKeyBuilder() if key_builder is None else key_builder
33
34 async def read(self) -> None:
35 async with open(self.file_path, "rb") as file:
36 json = await file.read()
37 self.records = self.records_adapter.validate_json(json)
38
39 async def flush(self) -> None:
40 async with open(self.file_path, "wb") as file:
41 json = self.records_adapter.dump_json(self.records)
42 await file.write(json)
43
44 async def get_record(self, key: StorageKey) -> Record:
45 await self.read()
46 record_key = self.key_builder.build(key)
47 if record_key not in self.records:
48 self.records[record_key] = Record()
49 return self.records[record_key]
50
51 async def set_state(self, key: StorageKey, state: StateType = None) -> None:
52 record = await self.get_record(key)
53 record.state = state.state if isinstance(state, State) else state
54 await self.flush()
55
56 async def get_state(self, key: StorageKey) -> str | None:
57 record = await self.get_record(key)
58 return record.state
59
60 async def set_data(self, key: StorageKey, data: Mapping[str, Any]) -> None:
61 if not isinstance(data, dict):
62 raise TypeError(
63 f"Data must be a dict or dict-like object, got {type(data).__name__}",
64 data,
65 )
66 record = await self.get_record(key)
67 record.data = data.copy()
68 await self.flush()
69
70 async def get_data(self, key: StorageKey) -> dict[str, Any]:
71 record = await self.get_record(key)
72 return record.data
73
74 async def close(self) -> None:
75 await self.flush()