33 次代碼提交 68c9c86b24 ... 8dda1a83b9

作者 SHA1 備註 提交日期
  Maarten van den Berg 8dda1a83b9 Merge branch 'aardbei-integration' 4 年之前
  Maarten van den Berg 79c2a4c05c Replace simpleaudio with QSoundEffects 4 年之前
  Maarten van den Berg 0965e3c94d Add renaming to piket-cli 4 年之前
  Maarten van den Berg f5f61fe197 Pin Python 3.7.5 4 年之前
  Maarten van den Berg 011cdf175f piket-cli: Allow setting Aardbei options via environment 4 年之前
  Maarten van den Berg 99d14fcd6e routes.aardbei: Fix crash in aardbei_apply 4 年之前
  Maarten van den Berg 7a1a003531 aardbei_sync: Actually use passed endpoint, log connection errors 4 年之前
  Maarten van den Berg b2f35d181c Update set_active to new client model 5 年之前
  Maarten van den Berg 7d7611916b Make isort play nice with Black 5 年之前
  Maarten van den Berg 7d600aa868 Add activities command to cli 5 年之前
  Maarten van den Berg abd3266d2a Fix bug in mismatch between IDs 5 年之前
  Maarten van den Berg 61fa8c247c Fix crash in /aardbei/diff_people 5 年之前
  Maarten van den Berg b220ea7c32 Fix Mypy errors 5 年之前
  Maarten van den Berg b6ce1b40be Export piket-cli entry point 5 年之前
  Maarten van den Berg f75bafe9f7 Enable unique constraints for unique fields 5 年之前
  Maarten van den Berg d398582d16 Add "people create" 5 年之前
  Maarten van den Berg 83a9386526 Add "people list" 5 年之前
  Maarten van den Berg eb636e6ebe Add subcommands for consumption types to cli 5 年之前
  Maarten van den Berg 4d67672ed1 Expose active, add patch route for ctypes 5 年之前
  Maarten van den Berg 1cb051b5e1 Add settlements create command 5 年之前
  Maarten van den Berg b67e022de5 Add command to show settlement 5 年之前
  Maarten van den Berg a92f2344c4 per_person_counts, NetworkError in client Settlement 5 年之前
  Maarten van den Berg 6d587be65b Begin work on new CLI 5 年之前
  Maarten van den Berg 98fdf2351f Add isort config 5 年之前
  Maarten van den Berg 47352b1bcc Add matching local state with an activity 5 年之前
  Maarten van den Berg cc9edc28bb Make aardbei_sync actually work 5 年之前
  Maarten van den Berg e80d70e999 Split up server code into modules 5 年之前
  Maarten van den Berg 1a612920f0 Add isort to dev dependencies 5 年之前
  Maarten van den Berg ab3a0b921c client.model: Fix type annotations 5 年之前
  Maarten van den Berg 9405cec96a Add new name fields to client 5 年之前
  Maarten van den Berg 29302bc68c Add script to sync people from Aardbei 5 年之前
  Maarten van den Berg aafd7fb250 Add mypy to dev dependencies 5 年之前
  Maarten van den Berg d9abfaa30d server: Add Aardbei fields to Person 5 年之前

+ 3 - 0
.gitignore

@@ -7,3 +7,6 @@
7 7
 # Python egg metadata, regenerated from source files by setuptools.
8 8
 /*.egg-info
9 9
 /*.egg
10
+
11
+# Type checking cache
12
+.mypy_cache

+ 6 - 0
.isort.cfg

@@ -0,0 +1,6 @@
1
+[settings]
2
+multi_line_output=3
3
+include_trailing_comma=True
4
+force_grid_wrap=0
5
+use_parentheses=True
6
+line_length=88

+ 1 - 0
.python-version

@@ -0,0 +1 @@
1
+3.7.5

+ 354 - 0
piket_client/cli.py

@@ -0,0 +1,354 @@
1
+from typing import Optional
2
+
3
+import click
4
+from prettytable import PrettyTable
5
+
6
+from piket_client.model import (
7
+    AardbeiActivity,
8
+    ServerStatus,
9
+    NetworkError,
10
+    Consumption,
11
+    AardbeiPeopleDiff,
12
+    Person,
13
+    Settlement,
14
+    ConsumptionType,
15
+)
16
+
17
+
18
+@click.group()
19
+def cli():
20
+    """Poke coco from the command line."""
21
+    pass
22
+
23
+
24
+@cli.command()
25
+def status():
26
+    """Show the current status of the server."""
27
+
28
+    status = ServerStatus.is_server_running()
29
+
30
+    if isinstance(status, NetworkError):
31
+        print_error(f"Failed to get data from server, error {status.value}")
32
+        return
33
+
34
+    print_ok("Server is available.")
35
+
36
+    open_consumptions = ServerStatus.unsettled_consumptions()
37
+
38
+    if isinstance(open_consumptions, NetworkError):
39
+        print_error(
40
+            f"Failed to get unsettled consumptions, error {open_consumptions.value}"
41
+        )
42
+        return
43
+
44
+    click.echo(f"There are {open_consumptions.amount} unsettled consumptions.")
45
+
46
+    if open_consumptions.amount > 0:
47
+        click.echo(f"First at: {open_consumptions.first_timestamp.strftime('%c')}")
48
+        click.echo(f"Most recent at: {open_consumptions.last_timestamp.strftime('%c')}")
49
+
50
+
51
+@cli.group()
52
+def people():
53
+    pass
54
+
55
+
56
+@people.command("list")
57
+@click.option("--active/--inactive", default=None)
58
+def list_people(active: bool) -> None:
59
+    people = Person.get_all(active=active)
60
+
61
+    if isinstance(people, NetworkError):
62
+        print_error(f"Could not get people: {people.value}")
63
+        return
64
+
65
+    table = PrettyTable()
66
+    table.field_names = ["ID", "Full name", "Display name", "Active"]
67
+    table.align["ID"] = "r"
68
+    table.align["Full name"] = "l"
69
+    table.align["Display name"] = "l"
70
+    table.sortby = "Full name"
71
+
72
+    for p in people:
73
+        table.add_row([p.person_id, p.full_name, p.display_name, p.active])
74
+
75
+    print(table)
76
+
77
+
78
+@people.command("create")
79
+@click.option("--display-name", type=click.STRING)
80
+@click.argument("name", type=click.STRING)
81
+def create_person(name: str, display_name: str) -> None:
82
+    """Create a person."""
83
+    person = Person(full_name=name, display_name=display_name).create()
84
+
85
+    if isinstance(person, NetworkError):
86
+        print_error(f"Could not create Person: {person.value}")
87
+        return
88
+
89
+    print_ok(f'Created person "{name}" with ID {person.person_id}.')
90
+
91
+
92
+@people.command("rename")
93
+@click.argument("person-id", type=click.INT)
94
+@click.option("--new-full-name", type=click.STRING)
95
+@click.option("--new-display-name", type=click.STRING)
96
+def rename_person(
97
+    person_id: int, new_full_name: Optional[str], new_display_name: Optional[str],
98
+) -> None:
99
+
100
+    person = Person.get(person_id)
101
+
102
+    if person is None:
103
+        raise click.UsageError(f"Cannot find Person {person_id}!")
104
+
105
+    if new_full_name is None and new_display_name is None:
106
+        raise click.UsageError("No new full name or display name specified!")
107
+
108
+    new_person = person.rename(
109
+        new_full_name=new_full_name, new_display_name=new_display_name
110
+    )
111
+
112
+
113
+@cli.group()
114
+def settlements():
115
+    pass
116
+
117
+
118
+@settlements.command("show")
119
+@click.argument("settlement_id", type=click.INT)
120
+def show_settlement(settlement_id: int) -> None:
121
+    """Get and view the contents of a Settlement."""
122
+    s = Settlement.get(settlement_id)
123
+
124
+    if isinstance(s, NetworkError):
125
+        print_error(f"Could not get Settlement: {s.value}")
126
+        return
127
+
128
+    output_settlement_info(s)
129
+
130
+
131
+@settlements.command("create")
132
+@click.argument("name")
133
+def create_settlement(name: str) -> None:
134
+    """Create a new Settlement."""
135
+    s = Settlement.create(name)
136
+
137
+    if isinstance(s, NetworkError):
138
+        print_error(f"Could not create Settlement: {s.value}")
139
+        return
140
+
141
+    output_settlement_info(s)
142
+
143
+
144
+def output_settlement_info(s: Settlement) -> None:
145
+    click.echo(f'Settlement {s.settlement_id}, "{s.name}"')
146
+
147
+    click.echo(f"Summary:")
148
+    for key, value in s.consumption_summary.items():
149
+        click.echo(f" - {value['count']} {value['name']} ({key})")
150
+
151
+    ct_name_by_id = {key: value["name"] for key, value in s.consumption_summary.items()}
152
+
153
+    table = PrettyTable()
154
+    table.field_names = ["Name", *ct_name_by_id.values()]
155
+    table.sortby = "Name"
156
+    table.align = "r"
157
+    table.align["Name"] = "l"  # type: ignore
158
+
159
+    zero_fields = {k: "" for k in ct_name_by_id.values()}
160
+
161
+    for item in s.per_person_counts.values():
162
+        r = {"Name": item["full_name"], **zero_fields}
163
+        for key, value in item["counts"].items():
164
+            r[ct_name_by_id[key]] = value
165
+
166
+        table.add_row(r.values())
167
+
168
+    print(table)
169
+
170
+
171
+@cli.group()
172
+def consumption_types():
173
+    pass
174
+
175
+
176
+@consumption_types.command("list")
177
+def list_consumption_types() -> None:
178
+    active = ConsumptionType.get_all(active=True)
179
+    inactive = ConsumptionType.get_all(active=False)
180
+
181
+    if isinstance(active, NetworkError) or isinstance(inactive, NetworkError):
182
+        print_error("Could not get consumption types!")
183
+        return
184
+
185
+    table = PrettyTable()
186
+    table.field_names = ["ID", "Name", "Active"]
187
+    table.sortby = "ID"
188
+
189
+    for ct in active + inactive:
190
+        table.add_row([ct.consumption_type_id, ct.name, ct.active])
191
+
192
+    print(table)
193
+
194
+
195
+@consumption_types.command("create")
196
+@click.argument("name")
197
+def create_consumption_type(name: str) -> None:
198
+    ct = ConsumptionType(name=name).create()
199
+
200
+    if not isinstance(ct, NetworkError):
201
+        print_ok(f'Created consumption type "{name}" with ID {ct.consumption_type_id}.')
202
+
203
+
204
+@consumption_types.command("activate")
205
+@click.argument("consumption_type_id", type=click.INT)
206
+def activate_consumption_type(consumption_type_id: int) -> None:
207
+    ct = ConsumptionType.get(consumption_type_id)
208
+
209
+    if isinstance(ct, NetworkError):
210
+        print_error(f"Could not get ConsumptionType: {ct.value}")
211
+        return
212
+
213
+    result = ct.set_active(True)
214
+
215
+    if not isinstance(result, NetworkError):
216
+        print_ok(
217
+            f"Consumption type {ct.consumption_type_id} ({ct.name}) is now active."
218
+        )
219
+
220
+
221
+@consumption_types.command("deactivate")
222
+@click.argument("consumption_type_id", type=click.INT)
223
+def deactivate_consumption_type(consumption_type_id: int) -> None:
224
+    ct = ConsumptionType.get(consumption_type_id)
225
+
226
+    if isinstance(ct, NetworkError):
227
+        print_error(f"Could not get ConsumptionType: {ct.value}")
228
+        return
229
+
230
+    result = ct.set_active(False)
231
+
232
+    if not isinstance(result, NetworkError):
233
+        print_ok(
234
+            f"Consumption type {ct.consumption_type_id} ({ct.name}) is now inactive."
235
+        )
236
+
237
+
238
+def print_ok(msg: str) -> None:
239
+    click.echo(click.style(msg, fg="green"))
240
+
241
+
242
+def print_error(msg: str) -> None:
243
+    click.echo(click.style(msg, fg="red", bold=True), err=True)
244
+
245
+
246
+@cli.group()
247
+@click.option("--token", required=True, envvar="AARDBEI_TOKEN")
248
+@click.option("--endpoint", default="http://localhost:3000", envvar="AARDBEI_ENDPOINT")
249
+@click.pass_context
250
+def aardbei(ctx, token: str, endpoint: str) -> None:
251
+    ctx.ensure_object(dict)
252
+    ctx.obj["AardbeiToken"] = token
253
+    ctx.obj["AardbeiEndpoint"] = endpoint
254
+
255
+
256
+@aardbei.group("activities")
257
+def aardbei_activities() -> None:
258
+    pass
259
+
260
+
261
+@aardbei_activities.command("list")
262
+@click.pass_context
263
+def aardbei_list_activities(ctx) -> None:
264
+    acts = AardbeiActivity.get_available(
265
+        token=ctx.obj["AardbeiToken"], endpoint=ctx.obj["AardbeiEndpoint"]
266
+    )
267
+
268
+    if isinstance(acts, NetworkError):
269
+        print_error(f"Could not get activities: {acts.value}")
270
+        return
271
+
272
+    table = PrettyTable()
273
+    table.field_names = ["ID", "Name"]
274
+    table.align = "l"
275
+
276
+    for a in acts:
277
+        table.add_row([a.aardbei_id, a.name])
278
+
279
+    print(table)
280
+
281
+
282
+@aardbei_activities.command("apply")
283
+@click.argument("activity_id", type=click.INT)
284
+@click.pass_context
285
+def aardbei_apply_activity(ctx, activity_id: int) -> None:
286
+    result = AardbeiActivity.apply_activity(
287
+        token=ctx.obj["AardbeiToken"],
288
+        endpoint=ctx.obj["AardbeiEndpoint"],
289
+        activity_id=activity_id,
290
+    )
291
+
292
+    if isinstance(result, NetworkError):
293
+        print_error("Failed to apply activity: {result.value}")
294
+        return
295
+
296
+    print_ok(f"Activity applied. There are now {result} active people.")
297
+
298
+
299
+@aardbei.group("people")
300
+def aardbei_people() -> None:
301
+    pass
302
+
303
+
304
+@aardbei_people.command("diff")
305
+@click.pass_context
306
+def aardbei_diff_people(ctx) -> None:
307
+    diff = AardbeiPeopleDiff.get_diff(
308
+        token=ctx.obj["AardbeiToken"], endpoint=ctx.obj["AardbeiEndpoint"]
309
+    )
310
+
311
+    if isinstance(diff, NetworkError):
312
+        print_error(f"Could not get differences: {diff.value}")
313
+        return
314
+
315
+    if diff.num_changes == 0:
316
+        print_ok("There are no changes to apply.")
317
+        return
318
+
319
+    click.echo(f"There are {diff.num_changes} pending changes:")
320
+    show_diff(diff)
321
+
322
+
323
+@aardbei_people.command("sync")
324
+@click.pass_context
325
+def aardbei_sync_people(ctx) -> None:
326
+    diff = AardbeiPeopleDiff.sync(
327
+        token=ctx.obj["AardbeiToken"], endpoint=ctx.obj["AardbeiEndpoint"]
328
+    )
329
+
330
+    if isinstance(diff, NetworkError):
331
+        print_error(f"Could not apply differences: {diff.value}")
332
+        return
333
+
334
+    if diff.num_changes == 0:
335
+        print_ok("There were no changes to apply.")
336
+        return
337
+
338
+    print_ok(f"Applied {diff.num_changes} pending changes:")
339
+    show_diff(diff)
340
+
341
+
342
+def show_diff(diff: AardbeiPeopleDiff) -> None:
343
+    for name in diff.new_people:
344
+        click.echo(f" - Create local Person for {name}")
345
+
346
+    for name in diff.link_existing:
347
+        click.echo(f" - Link local and remote people for {name}")
348
+
349
+    for name in diff.altered_name:
350
+        click.echo(f" - Process name change for {name}")
351
+
352
+
353
+if __name__ == "__main__":
354
+    cli()

+ 55 - 27
piket_client/gui.py

@@ -2,9 +2,12 @@
2 2
 Provides the graphical front-end for Piket.
3 3
 """
4 4
 import collections
5
+import itertools
5 6
 import logging
7
+import math
6 8
 import os
7 9
 import sys
10
+from typing import Deque, Iterator
8 11
 
9 12
 import qdarkstyle
10 13
 
@@ -24,7 +27,8 @@ from PySide2.QtWidgets import (
24 27
     QWidget,
25 28
 )
26 29
 from PySide2.QtGui import QIcon
27
-from PySide2.QtCore import QObject, QSize, Qt, Signal, Slot
30
+from PySide2.QtCore import QObject, QSize, Qt, Signal, Slot, QUrl
31
+from PySide2.QtMultimedia import QSoundEffect
28 32
 
29 33
 # pylint: enable=E0611
30 34
 
@@ -33,12 +37,13 @@ try:
33 37
 except ImportError:
34 38
     dbus = None
35 39
 
36
-from piket_client.sound import PLOP_WAVE, UNDO_WAVE
40
+from piket_client.sound import PLOP_PATH, UNDO_PATH
37 41
 from piket_client.model import (
38 42
     Person,
39 43
     ConsumptionType,
40 44
     Consumption,
41 45
     ServerStatus,
46
+    NetworkError,
42 47
     Settlement,
43 48
 )
44 49
 import piket_client.logger
@@ -46,11 +51,6 @@ import piket_client.logger
46 51
 LOG = logging.getLogger(__name__)
47 52
 
48 53
 
49
-def plop() -> None:
50
-    """ Asynchronously play the plop sound. """
51
-    PLOP_WAVE.play()
52
-
53
-
54 54
 class NameButton(QPushButton):
55 55
     """ Wraps a QPushButton to provide a counter. """
56 56
 
@@ -77,7 +77,7 @@ class NameButton(QPushButton):
77 77
     @Slot()
78 78
     def rebuild(self) -> None:
79 79
         """ Refresh the Person object and the label. """
80
-        self.person = self.person.reload()
80
+        self.person = self.person.reload()  # type: ignore
81 81
         self.setText(self.current_label)
82 82
 
83 83
     @property
@@ -96,7 +96,7 @@ class NameButton(QPushButton):
96 96
         LOG.debug("Button clicked.")
97 97
         result = self.person.add_consumption(self.active_id)
98 98
         if result:
99
-            plop()
99
+            self.window().play_plop()
100 100
             self.setText(self.current_label)
101 101
             self.consumption_created.emit(result)
102 102
         else:
@@ -148,7 +148,8 @@ class NameButtons(QWidget):
148 148
         LOG.debug("Initializing NameButtons.")
149 149
 
150 150
         ps = Person.get_all(True)
151
-        num_columns = round(len(ps) / 10) + 1
151
+        assert not isinstance(ps, NetworkError)
152
+        num_columns = math.ceil(math.sqrt(len(ps)))
152 153
 
153 154
         if self.layout:
154 155
             LOG.debug("Removing %s widgets for rebuild", self.layout.count())
@@ -173,6 +174,9 @@ class PiketMainWindow(QMainWindow):
173 174
 
174 175
     consumption_type_changed = Signal(str)
175 176
 
177
+    plop_loop: Iterator[QSoundEffect]
178
+    undo_loop: Iterator[QSoundEffect]
179
+
176 180
     def __init__(self) -> None:
177 181
         LOG.debug("Initializing PiketMainWindow.")
178 182
         super().__init__()
@@ -182,7 +186,7 @@ class PiketMainWindow(QMainWindow):
182 186
         self.toolbar = None
183 187
         self.osk = None
184 188
         self.undo_action = None
185
-        self.undo_queue = collections.deque([], 15)
189
+        self.undo_queue: Deque[Consumption] = collections.deque([], 15)
186 190
         self.init_ui()
187 191
 
188 192
     def init_ui(self) -> None:
@@ -211,6 +215,7 @@ class PiketMainWindow(QMainWindow):
211 215
 
212 216
         # Initialize toolbar
213 217
         self.toolbar = QToolBar()
218
+        assert self.toolbar is not None
214 219
         self.toolbar.setToolButtonStyle(Qt.ToolButtonTextUnderIcon)
215 220
         self.toolbar.setIconSize(QSize(icon_size, icon_size))
216 221
 
@@ -238,7 +243,7 @@ class PiketMainWindow(QMainWindow):
238 243
         self.toolbar.setContextMenuPolicy(Qt.PreventContextMenu)
239 244
         self.toolbar.setFloatable(False)
240 245
         self.toolbar.setMovable(False)
241
-        self.ct_ag = QActionGroup(self.toolbar)
246
+        self.ct_ag: QActionGroup = QActionGroup(self.toolbar)
242 247
         self.ct_ag.setExclusive(True)
243 248
 
244 249
         cts = ConsumptionType.get_all()
@@ -287,6 +292,17 @@ class PiketMainWindow(QMainWindow):
287 292
 
288 293
         self.addToolBar(self.toolbar)
289 294
 
295
+        # Load sounds
296
+        plops = [QSoundEffect(self) for _ in range(7)]
297
+        for qse in plops:
298
+            qse.setSource(QUrl.fromLocalFile(str(PLOP_PATH)))
299
+        self.plop_loop = itertools.cycle(plops)
300
+
301
+        undos = [QSoundEffect(self) for _ in range(5)]
302
+        for qse in undos:
303
+            qse.setSource(QUrl.fromLocalFile(str(UNDO_PATH)))
304
+        self.undo_loop = itertools.cycle(undos)
305
+
290 306
         # Initialize main widget
291 307
         self.main_widget = NameButtons(self.ct_ag.actions()[0].data(), self)
292 308
         self.consumption_type_changed.connect(self.main_widget.consumption_type_changed)
@@ -310,6 +326,8 @@ class PiketMainWindow(QMainWindow):
310 326
         """ Ask for a new Person and register it, then rebuild the central
311 327
         widget. """
312 328
         inactive_persons = Person.get_all(False)
329
+        assert not isinstance(inactive_persons, NetworkError)
330
+
313 331
         inactive_persons.sort(key=lambda p: p.name)
314 332
         inactive_names = [p.name for p in inactive_persons]
315 333
 
@@ -330,9 +348,10 @@ class PiketMainWindow(QMainWindow):
330 348
                 person.set_active(True)
331 349
 
332 350
             else:
333
-                person = Person(name=name)
334
-                person = person.create()
351
+                person = Person(full_name=name, display_name=None,)
352
+                person.create()
335 353
 
354
+            assert self.main_widget is not None
336 355
             self.main_widget.init_ui()
337 356
 
338 357
     def add_consumption_type(self) -> None:
@@ -343,8 +362,8 @@ class PiketMainWindow(QMainWindow):
343 362
         self.hide_keyboard()
344 363
 
345 364
         if ok and name:
346
-            ct = ConsumptionType(name=name)
347
-            ct = ct.create()
365
+            ct = ConsumptionType(name=name).create()
366
+            assert not isinstance(ct, NetworkError)
348 367
 
349 368
             action = QAction(
350 369
                 self.load_icon(ct.icon or "beer_bottle.svg"), ct.name, self.ct_ag
@@ -352,6 +371,7 @@ class PiketMainWindow(QMainWindow):
352 371
             action.setCheckable(True)
353 372
             action.setData(str(ct.consumption_type_id))
354 373
 
374
+            assert self.toolbar is not None
355 375
             self.toolbar.addAction(action)
356 376
 
357 377
     def confirm_quit(self) -> None:
@@ -370,7 +390,7 @@ class PiketMainWindow(QMainWindow):
370 390
 
371 391
     def do_undo(self) -> None:
372 392
         """ Undo the last marked consumption. """
373
-        UNDO_WAVE.play()
393
+        next(self.undo_loop).play()
374 394
 
375 395
         to_undo = self.undo_queue.pop()
376 396
         LOG.warning("Undoing consumption %s", to_undo)
@@ -382,8 +402,10 @@ class PiketMainWindow(QMainWindow):
382 402
             self.undo_queue.append(to_undo)
383 403
 
384 404
         elif not self.undo_queue:
405
+            assert self.undo_action is not None
385 406
             self.undo_action.setDisabled(True)
386 407
 
408
+        assert self.main_widget is not None
387 409
         self.main_widget.init_ui()
388 410
 
389 411
     @Slot(Consumption)
@@ -412,6 +434,9 @@ class PiketMainWindow(QMainWindow):
412 434
         icon = QIcon(os.path.join(self.icons_dir, filename))
413 435
         return icon
414 436
 
437
+    def play_plop(self) -> None:
438
+        next(self.plop_loop).play()
439
+
415 440
 
416 441
 def main() -> None:
417 442
     """ Main entry point of GUI client. """
@@ -428,28 +453,31 @@ def main() -> None:
428 453
     app.setFont(font)
429 454
 
430 455
     # Test connectivity
431
-    server_running, info = ServerStatus.is_server_running()
456
+    server_running = ServerStatus.is_server_running()
432 457
 
433
-    if not server_running:
434
-        LOG.critical("Could not connect to server", extra={"info": info})
458
+    if isinstance(server_running, NetworkError):
459
+        LOG.critical("Could not connect to server, error %s", server_running.value)
435 460
         QMessageBox.critical(
436 461
             None,
437 462
             "Help er is iets kapot",
438 463
             "Kan niet starten omdat de server niet reageert, stuur een foto van "
439
-            "dit naar Maarten: " + repr(info),
464
+            "dit naar Maarten: " + repr(server_running.value),
440 465
         )
441
-        return 1
466
+        return
442 467
 
443 468
     # Load main window
444 469
     main_window = PiketMainWindow()
445 470
 
446 471
     # Test unsettled consumptions
447 472
     status = ServerStatus.unsettled_consumptions()
473
+    assert not isinstance(status, NetworkError)
448 474
 
449
-    unsettled = status["unsettled"]["amount"]
475
+    unsettled = status.amount
450 476
 
451 477
     if unsettled > 0:
452
-        first = status["unsettled"]["first"]
478
+        assert status.first_timestamp is not None
479
+
480
+        first = status.first_timestamp
453 481
         first_date = first.strftime("%c")
454 482
         ok = QMessageBox.information(
455 483
             None,
@@ -464,7 +492,7 @@ def main() -> None:
464 492
             name, ok = QInputDialog.getText(
465 493
                 None,
466 494
                 "Lijst afsluiten",
467
-                "Voer een naam in voor de lijst of druk op OK. Laat de datum " "staan.",
495
+                "Voer een naam in voor de lijst of druk op OK. Laat de datum staan.",
468 496
                 QLineEdit.Normal,
469 497
                 f"{first.strftime('%Y-%m-%d')}",
470 498
             )
@@ -476,9 +504,9 @@ def main() -> None:
476 504
                     f'{item["count"]} {item["name"]}'
477 505
                     for item in settlement.consumption_summary.values()
478 506
                 ]
479
-                info = ", ".join(info)
507
+                info2 = ", ".join(info)
480 508
                 QMessageBox.information(
481
-                    None, "Lijst afgesloten", f"VO! Op deze lijst stonden: {info}"
509
+                    None, "Lijst afgesloten", f"VO! Op deze lijst stonden: {info2}"
482 510
                 )
483 511
 
484 512
                 main_window = PiketMainWindow()

+ 363 - 127
piket_client/model.py

@@ -1,61 +1,108 @@
1 1
 """
2 2
 Provides access to the models stored in the database, via the server.
3 3
 """
4
+from __future__ import annotations
5
+
4 6
 import datetime
7
+import enum
5 8
 import logging
6
-from typing import NamedTuple, Sequence
9
+from dataclasses import dataclass
10
+from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
7 11
 from urllib.parse import urljoin
8 12
 
9 13
 import requests
10 14
 
11
-
12 15
 LOG = logging.getLogger(__name__)
13 16
 
14 17
 SERVER_URL = "http://127.0.0.1:5000"
15 18
 DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"
16 19
 
17 20
 
21
+class NetworkError(enum.Enum):
22
+    """Represents errors that might occur when communicating with the server."""
23
+
24
+    HttpFailure = "http_failure"
25
+    """Returned when the server returns a non-successful status code."""
26
+
27
+    ConnectionFailure = "connection_failure"
28
+    """Returned when we can't connect to the server at all."""
29
+
30
+    InvalidData = "invalid_data"
31
+
32
+
18 33
 class ServerStatus:
19 34
     """ Provides helper classes to check whether the server is up. """
20 35
 
21 36
     @classmethod
22
-    def is_server_running(cls) -> bool:
37
+    def is_server_running(cls) -> Union[bool, NetworkError]:
23 38
         try:
24 39
             req = requests.get(urljoin(SERVER_URL, "ping"))
25
-
26
-            if req.status_code == 200:
27
-                return True, req.content
28
-            return False, req.content
40
+            req.raise_for_status()
29 41
 
30 42
         except requests.ConnectionError as ex:
31
-            return False, ex
43
+            LOG.exception(ex)
44
+            return NetworkError.ConnectionFailure
45
+
46
+        except requests.HTTPError as ex:
47
+            LOG.exception(ex)
48
+            return NetworkError.HttpFailure
49
+
50
+        return True
51
+
52
+    @dataclass(frozen=True)
53
+    class OpenConsumptions:
54
+        amount: int
55
+        first_timestamp: Optional[datetime.datetime]
56
+        last_timestamp: Optional[datetime.datetime]
32 57
 
33 58
     @classmethod
34
-    def unsettled_consumptions(cls) -> dict:
35
-        req = requests.get(urljoin(SERVER_URL, "status"))
59
+    def unsettled_consumptions(cls) -> Union[OpenConsumptions, NetworkError]:
60
+        try:
61
+            req = requests.get(urljoin(SERVER_URL, "status"))
62
+            req.raise_for_status()
63
+            data = req.json()
36 64
 
37
-        data = req.json()
65
+        except requests.ConnectionError as e:
66
+            LOG.exception(e)
67
+            return NetworkError.ConnectionFailure
38 68
 
39
-        if data["unsettled"]["amount"]:
40
-            data["unsettled"]["first"] = datetime.datetime.strptime(
41
-                data["unsettled"]["first"], DATETIME_FORMAT
42
-            )
43
-            data["unsettled"]["last"] = datetime.datetime.strptime(
44
-                data["unsettled"]["last"], DATETIME_FORMAT
69
+        except requests.HTTPError as e:
70
+            LOG.exception(e)
71
+            return NetworkError.HttpFailure
72
+
73
+        except ValueError as e:
74
+            LOG.exception(e)
75
+            return NetworkError.InvalidData
76
+
77
+        amount: int = data["unsettled"]["amount"]
78
+
79
+        if amount == 0:
80
+            return cls.OpenConsumptions(
81
+                amount=0, first_timestamp=None, last_timestamp=None
45 82
             )
46 83
 
47
-        return data
84
+        first = datetime.datetime.fromisoformat(data["unsettled"]["first"])
85
+        last = datetime.datetime.fromisoformat(data["unsettled"]["last"])
86
+
87
+        return cls.OpenConsumptions(
88
+            amount=amount, first_timestamp=first, last_timestamp=last
89
+        )
48 90
 
49 91
 
50 92
 class Person(NamedTuple):
51 93
     """ Represents a Person, as retrieved from the database. """
52 94
 
53
-    name: str
95
+    full_name: str
96
+    display_name: Optional[str]
54 97
     active: bool = True
55
-    person_id: int = None
98
+    person_id: Optional[int] = None
56 99
     consumptions: dict = {}
57 100
 
58
-    def add_consumption(self, type_id: str) -> bool:
101
+    @property
102
+    def name(self) -> str:
103
+        return self.display_name or self.full_name
104
+
105
+    def add_consumption(self, type_id: str) -> Optional[Consumption]:
59 106
         """ Register a consumption for this Person. """
60 107
         req = requests.post(
61 108
             urljoin(SERVER_URL, f"people/{self.person_id}/add_consumption/{type_id}")
@@ -70,7 +117,7 @@ class Person(NamedTuple):
70 117
                     req.status_code,
71 118
                     data,
72 119
                 )
73
-                return False
120
+                return None
74 121
 
75 122
             self.consumptions.update(data["person"]["consumptions"])
76 123
 
@@ -81,33 +128,72 @@ class Person(NamedTuple):
81 128
                 req.status_code,
82 129
                 req.content,
83 130
             )
84
-            return False
131
+            return None
85 132
 
86
-    def create(self) -> "Person":
133
+    def create(self) -> Union[Person, NetworkError]:
87 134
         """ Create a new Person from the current attributes. As tuples are
88 135
         immutable, a new Person with the correct id is returned. """
89
-        req = requests.post(
90
-            urljoin(SERVER_URL, "people"),
91
-            json={"person": {"name": self.name, "active": True}},
136
+
137
+        try:
138
+            req = requests.post(
139
+                urljoin(SERVER_URL, "people"),
140
+                json={
141
+                    "person": {
142
+                        "full_name": self.full_name,
143
+                        "display_name": self.display_name,
144
+                        "active": True,
145
+                    }
146
+                },
147
+            )
148
+            req.raise_for_status()
149
+            data = req.json()
150
+            return Person.from_dict(data["person"])
151
+
152
+        except requests.ConnectionError as e:
153
+            LOG.exception(e)
154
+            return NetworkError.ConnectionFailure
155
+
156
+        except requests.HTTPError as e:
157
+            LOG.exception(e)
158
+            return NetworkError.HttpFailure
159
+
160
+        except ValueError as e:
161
+            LOG.exception(e)
162
+            return NetworkError.InvalidData
163
+
164
+    def rename(
165
+        self, new_full_name: Optional[str], new_display_name: Optional[str]
166
+    ) -> Optional[Person]:
167
+        person_payload: Dict[str, str] = {}
168
+
169
+        if new_full_name is not None:
170
+            person_payload["full_name"] = new_full_name
171
+
172
+        if new_display_name is not None:
173
+            person_payload["display_name"] = new_display_name
174
+
175
+        req = requests.patch(
176
+            urljoin(SERVER_URL, f"people/{self.person_id}"),
177
+            json={"person": person_payload},
92 178
         )
93 179
 
94 180
         try:
95 181
             data = req.json()
96 182
         except ValueError:
97 183
             LOG.error(
98
-                "Did not get JSON on adding Person (%s): %s",
184
+                "Did not get JSON on updating Person (%s): %s",
99 185
                 req.status_code,
100 186
                 req.content,
101 187
             )
102 188
             return None
103 189
 
104
-        if "error" in data or req.status_code != 201:
105
-            LOG.error("Could not create Person (%s): %s", req.status_code, data)
190
+        if "error" in data or req.status_code != 200:
191
+            LOG.error("Could not update Person (%s): %s", req.status_code, data)
106 192
             return None
107 193
 
108 194
         return Person.from_dict(data["person"])
109 195
 
110
-    def set_active(self, new_state=True) -> "Person":
196
+    def set_active(self, new_state=True) -> Optional[Person]:
111 197
         req = requests.patch(
112 198
             urljoin(SERVER_URL, f"people/{self.person_id}"),
113 199
             json={"person": {"active": new_state}},
@@ -130,7 +216,7 @@ class Person(NamedTuple):
130 216
         return Person.from_dict(data["person"])
131 217
 
132 218
     @classmethod
133
-    def get(cls, person_id: int) -> "Person":
219
+    def get(cls, person_id: int) -> Optional[Person]:
134 220
         """ Retrieve a Person by id. """
135 221
         req = requests.get(urljoin(SERVER_URL, f"/people/{person_id}"))
136 222
 
@@ -154,35 +240,36 @@ class Person(NamedTuple):
154 240
             return None
155 241
 
156 242
     @classmethod
157
-    def get_all(cls, active=None) -> ["Person"]:
243
+    def get_all(cls, active=None) -> Union[List[Person], NetworkError]:
158 244
         """ Get all active People. """
159 245
         params = {}
160 246
         if active is not None:
161 247
             params["active"] = int(active)
162 248
 
163
-        req = requests.get(urljoin(SERVER_URL, "/people"), params=params)
164
-
165 249
         try:
250
+            req = requests.get(urljoin(SERVER_URL, "/people"), params=params)
251
+            req.raise_for_status()
166 252
             data = req.json()
253
+            return [Person.from_dict(item) for item in data["people"]]
167 254
 
168
-            if "error" in data:
169
-                LOG.warning("Could not get people (%s): %s", req.status_code, data)
255
+        except requests.ConnectionError as e:
256
+            LOG.exception(e)
257
+            return NetworkError.ConnectionFailure
170 258
 
171
-            return [Person.from_dict(item) for item in data["people"]]
259
+        except requests.HTTPError as e:
260
+            LOG.exception(e)
261
+            return NetworkError.HttpFailure
172 262
 
173
-        except ValueError:
174
-            LOG.error(
175
-                "Did not get JSON from server on getting People (%s): %s",
176
-                req.status_code,
177
-                req.content,
178
-            )
179
-            return None
263
+        except ValueError as e:
264
+            LOG.exception(e)
265
+            return NetworkError.InvalidData
180 266
 
181 267
     @classmethod
182 268
     def from_dict(cls, data: dict) -> "Person":
183 269
         """ Reconstruct a Person object from a dict. """
184 270
         return Person(
185
-            name=data["name"],
271
+            full_name=data["full_name"],
272
+            display_name=data["display_name"],
186 273
             active=data["active"],
187 274
             person_id=data["person_id"],
188 275
             consumptions=data["consumptions"],
@@ -206,7 +293,7 @@ class Export(NamedTuple):
206 293
         )
207 294
 
208 295
     @classmethod
209
-    def get_all(cls) -> ["Export"]:
296
+    def get_all(cls) -> Optional[List[Export]]:
210 297
         """ Get a list of all existing Exports. """
211 298
         req = requests.get(urljoin(SERVER_URL, "exports"))
212 299
 
@@ -227,7 +314,7 @@ class Export(NamedTuple):
227 314
         return [cls.from_dict(e) for e in data["exports"]]
228 315
 
229 316
     @classmethod
230
-    def get(cls, export_id: int) -> "Export":
317
+    def get(cls, export_id: int) -> Optional[Export]:
231 318
         """ Retrieve one Export. """
232 319
         req = requests.get(urljoin(SERVER_URL, f"exports/{export_id}"))
233 320
 
@@ -250,7 +337,7 @@ class Export(NamedTuple):
250 337
         return cls.from_dict(data["export"])
251 338
 
252 339
     @classmethod
253
-    def create(cls) -> "Export":
340
+    def create(cls) -> Optional[Export]:
254 341
         """ Create a new Export, containing all un-exported Settlements. """
255 342
         req = requests.post(urljoin(SERVER_URL, "exports"))
256 343
 
@@ -277,87 +364,85 @@ class ConsumptionType(NamedTuple):
277 364
     """ Represents a stored ConsumptionType. """
278 365
 
279 366
     name: str
280
-    consumption_type_id: int = None
281
-    icon: str = None
367
+    consumption_type_id: Optional[int] = None
368
+    icon: Optional[str] = None
369
+    active: bool = True
282 370
 
283
-    def create(self) -> "ConsumptionType":
371
+    def create(self) -> Union[ConsumptionType, NetworkError]:
284 372
         """ Create a new ConsumptionType from the current attributes. As tuples
285 373
         are immutable, a new ConsumptionType with the correct id is returned.
286 374
         """
287
-        req = requests.post(
288
-            urljoin(SERVER_URL, "consumption_types"),
289
-            json={"consumption_type": {"name": self.name, "icon": self.icon}},
290
-        )
291
-
292 375
         try:
293
-            data = req.json()
294
-        except ValueError:
295
-            LOG.error(
296
-                "Did not get JSON on adding ConsumptionType (%s): %s",
297
-                req.status_code,
298
-                req.content,
376
+            req = requests.post(
377
+                urljoin(SERVER_URL, "consumption_types"),
378
+                json={"consumption_type": {"name": self.name, "icon": self.icon}},
299 379
             )
300
-            return None
301 380
 
302
-        if "error" in data or req.status_code != 201:
303
-            LOG.error(
304
-                "Could not create ConsumptionType (%s): %s", req.status_code, data
305
-            )
306
-            return None
381
+            req.raise_for_status()
382
+            data = req.json()
383
+            return ConsumptionType.from_dict(data["consumption_type"])
307 384
 
308
-        return ConsumptionType.from_dict(data["consumption_type"])
385
+        except requests.ConnectionError as e:
386
+            LOG.exception(e)
387
+            return NetworkError.ConnectionFailure
388
+
389
+        except requests.HTTPError as e:
390
+            LOG.exception(e)
391
+            return NetworkError.HttpFailure
392
+
393
+        except ValueError as e:
394
+            LOG.exception(e)
395
+            return NetworkError.InvalidData
309 396
 
310 397
     @classmethod
311
-    def get(cls, consumption_type_id: int) -> "ConsumptionType":
398
+    def get(cls, consumption_type_id: int) -> Union[ConsumptionType, NetworkError]:
312 399
         """ Retrieve a ConsumptionType by id. """
313
-        req = requests.get(
314
-            urljoin(SERVER_URL, f"/consumption_types/{consumption_type_id}")
315
-        )
316
-
317 400
         try:
401
+            req = requests.get(
402
+                urljoin(SERVER_URL, f"/consumption_types/{consumption_type_id}")
403
+            )
404
+            req.raise_for_status()
318 405
             data = req.json()
319 406
 
320
-            if "error" in data:
321
-                LOG.warning(
322
-                    "Could not get consumption type %s (%s): %s",
323
-                    consumption_type_id,
324
-                    req.status_code,
325
-                    data,
326
-                )
327
-                return None
407
+        except requests.ConnectionError as e:
408
+            LOG.exception(e)
409
+            return NetworkError.ConnectionFailure
328 410
 
329
-            return cls.from_dict(data["consumption_type"])
411
+        except requests.HTTPError as e:
412
+            LOG.exception(e)
413
+            return NetworkError.HttpFailure
330 414
 
331
-        except ValueError:
332
-            LOG.error(
333
-                "Did not get JSON from server on getting consumption type (%s): %s",
334
-                req.status_code,
335
-                req.content,
336
-            )
337
-            return None
415
+        except ValueError as e:
416
+            LOG.exception(e)
417
+            return NetworkError.InvalidData
338 418
 
339
-    @classmethod
340
-    def get_all(cls) -> ["ConsumptionType"]:
341
-        """ Get all active ConsumptionTypes. """
342
-        req = requests.get(urljoin(SERVER_URL, "/consumption_types"))
419
+        return cls.from_dict(data["consumption_type"])
343 420
 
421
+    @classmethod
422
+    def get_all(cls, active: bool = True) -> Union[List[ConsumptionType], NetworkError]:
423
+        """ Get the list of ConsumptionTypes. """
344 424
         try:
425
+            req = requests.get(
426
+                urljoin(SERVER_URL, "/consumption_types"),
427
+                params={"active": int(active)},
428
+            )
429
+            req.raise_for_status()
430
+
345 431
             data = req.json()
346 432
 
347
-            if "error" in data:
348
-                LOG.warning(
349
-                    "Could not get consumption types (%s): %s", req.status_code, data
350
-                )
433
+        except requests.ConnectionError as e:
434
+            LOG.exception(e)
435
+            return NetworkError.ConnectionFailure
351 436
 
352
-            return [cls.from_dict(item) for item in data["consumption_types"]]
437
+        except requests.HTTPError as e:
438
+            LOG.exception(e)
439
+            return NetworkError.HttpFailure
353 440
 
354
-        except ValueError:
355
-            LOG.error(
356
-                "Did not get JSON from server on getting ConsumptionTypes (%s): %s",
357
-                req.status_code,
358
-                req.content,
359
-            )
360
-            return None
441
+        except ValueError as e:
442
+            LOG.exception(e)
443
+            return NetworkError.InvalidData
444
+
445
+        return [cls.from_dict(x) for x in data["consumption_types"]]
361 446
 
362 447
     @classmethod
363 448
     def from_dict(cls, data: dict) -> "ConsumptionType":
@@ -366,8 +451,33 @@ class ConsumptionType(NamedTuple):
366 451
             name=data["name"],
367 452
             consumption_type_id=data["consumption_type_id"],
368 453
             icon=data.get("icon"),
454
+            active=data["active"],
369 455
         )
370 456
 
457
+    def set_active(self, active: bool) -> Union[ConsumptionType, NetworkError]:
458
+        """Update the 'active' attribute."""
459
+        try:
460
+            req = requests.patch(
461
+                urljoin(SERVER_URL, f"/consumption_types/{self.consumption_type_id}"),
462
+                json={"consumption_type": {"active": active}},
463
+            )
464
+            req.raise_for_status()
465
+            data = req.json()
466
+
467
+        except requests.ConnectionError as e:
468
+            LOG.exception(e)
469
+            return NetworkError.ConnectionFailure
470
+
471
+        except requests.HTTPError as e:
472
+            LOG.exception(e)
473
+            return NetworkError.HttpFailure
474
+
475
+        except ValueError as e:
476
+            LOG.exception(e)
477
+            return NetworkError.InvalidData
478
+
479
+        return self.from_dict(data["consumption_type"])
480
+
371 481
 
372 482
 class Consumption(NamedTuple):
373 483
     """ Represents a stored Consumption. """
@@ -377,7 +487,7 @@ class Consumption(NamedTuple):
377 487
     consumption_type_id: int
378 488
     created_at: datetime.datetime
379 489
     reversed: bool = False
380
-    settlement_id: int = None
490
+    settlement_id: Optional[int] = None
381 491
 
382 492
     @classmethod
383 493
     def from_dict(cls, data: dict) -> "Consumption":
@@ -391,7 +501,7 @@ class Consumption(NamedTuple):
391 501
             reversed=data["reversed"],
392 502
         )
393 503
 
394
-    def reverse(self) -> "Consumption":
504
+    def reverse(self) -> Optional[Consumption]:
395 505
         """ Reverse this consumption. """
396 506
         req = requests.delete(
397 507
             urljoin(SERVER_URL, f"/consumptions/{self.consumption_id}")
@@ -407,7 +517,7 @@ class Consumption(NamedTuple):
407 517
                     req.status_code,
408 518
                     data,
409 519
                 )
410
-                return False
520
+                return None
411 521
 
412 522
             return Consumption.from_dict(data["consumption"])
413 523
 
@@ -417,7 +527,7 @@ class Consumption(NamedTuple):
417 527
                 req.status_code,
418 528
                 req.content,
419 529
             )
420
-            return False
530
+            return None
421 531
 
422 532
 
423 533
 class Settlement(NamedTuple):
@@ -425,8 +535,9 @@ class Settlement(NamedTuple):
425 535
 
426 536
     settlement_id: int
427 537
     name: str
428
-    consumption_summary: dict
429
-    count_info: dict = {}
538
+    consumption_summary: Dict[str, Any]
539
+    count_info: Dict[str, Any] = {}
540
+    per_person_counts: Dict[str, Any] = {}
430 541
 
431 542
     @classmethod
432 543
     def from_dict(cls, data: dict) -> "Settlement":
@@ -434,7 +545,8 @@ class Settlement(NamedTuple):
434 545
             settlement_id=data["settlement_id"],
435 546
             name=data["name"],
436 547
             consumption_summary=data["consumption_summary"],
437
-            count_info=data.get("count_info", {}),
548
+            count_info=data["count_info"],
549
+            per_person_counts=data["per_person_counts"],
438 550
         )
439 551
 
440 552
     @classmethod
@@ -446,23 +558,147 @@ class Settlement(NamedTuple):
446 558
         return cls.from_dict(req.json()["settlement"])
447 559
 
448 560
     @classmethod
449
-    def get(cls, settlement_id: int) -> "Settlement":
450
-        req = requests.get(urljoin(SERVER_URL, f"/settlements/{settlement_id}"))
451
-
561
+    def get(cls, settlement_id: int) -> Union[Settlement, NetworkError]:
452 562
         try:
563
+            req = requests.get(urljoin(SERVER_URL, f"/settlements/{settlement_id}"))
564
+            req.raise_for_status()
453 565
             data = req.json()
454
-        except ValueError:
455
-            LOG.error(
456
-                "Did not get JSON on retrieving Settlement (%s): %s",
457
-                req.status_code,
458
-                req.content,
459
-            )
460
-            return None
461 566
 
462
-        if "error" in data or req.status_code != 200:
463
-            LOG.error("Could not get Export (%s): %s", req.status_code, data)
464
-            return None
567
+        except ValueError as e:
568
+            LOG.exception(e)
569
+            return NetworkError.InvalidData
570
+
571
+        except requests.ConnectionError as e:
572
+            LOG.exception(e)
573
+            return NetworkError.ConnectionFailure
574
+
575
+        except requests.HTTPError as e:
576
+            LOG.exception(e)
577
+            return NetworkError.HttpFailure
465 578
 
466 579
         data["settlement"]["count_info"] = data["count_info"]
467 580
 
468 581
         return cls.from_dict(data["settlement"])
582
+
583
+
584
+@dataclass(frozen=True)
585
+class AardbeiActivity:
586
+    aardbei_id: int
587
+    name: str
588
+
589
+    @classmethod
590
+    def from_dict(cls, data: Dict[str, Any]) -> AardbeiActivity:
591
+        return cls(data["activity"]["id"], data["activity"]["name"])
592
+
593
+    @classmethod
594
+    def get_available(
595
+        cls, token: str, endpoint: str
596
+    ) -> Union[List[AardbeiActivity], NetworkError]:
597
+        try:
598
+            req = requests.post(
599
+                urljoin(SERVER_URL, "/aardbei/get_activities"),
600
+                json={"endpoint": endpoint, "token": token},
601
+            )
602
+
603
+            req.raise_for_status()
604
+            return [cls.from_dict(x) for x in req.json()["activities"]]
605
+
606
+        except requests.ConnectionError as e:
607
+            LOG.exception(e)
608
+            return NetworkError.ConnectionFailure
609
+
610
+        except requests.HTTPError as e:
611
+            LOG.exception(e)
612
+            return NetworkError.HttpFailure
613
+
614
+        except ValueError as e:
615
+            LOG.exception(e)
616
+            return NetworkError.InvalidData
617
+
618
+    @classmethod
619
+    def apply_activity(
620
+        cls, token: str, endpoint: str, activity_id: int
621
+    ) -> Union[int, NetworkError]:
622
+        try:
623
+            req = requests.post(
624
+                urljoin(SERVER_URL, "/aardbei/apply_activity"),
625
+                json={"activity_id": activity_id, "token": token, "endpoint": endpoint},
626
+            )
627
+            req.raise_for_status()
628
+            data = req.json()
629
+
630
+            return data["activity"]["response_counts"]["present"]
631
+
632
+        except requests.ConnectionError as e:
633
+            LOG.exception(e)
634
+            return NetworkError.ConnectionFailure
635
+
636
+        except requests.HTTPError as e:
637
+            LOG.exception(e)
638
+            return NetworkError.HttpFailure
639
+
640
+        except ValueError as e:
641
+            LOG.exception(e)
642
+            return NetworkError.InvalidData
643
+
644
+
645
+@dataclass(frozen=True)
646
+class AardbeiPeopleDiff:
647
+    altered_name: List[str]
648
+    link_existing: List[str]
649
+    new_people: List[str]
650
+    num_changes: int
651
+
652
+    @classmethod
653
+    def from_dict(cls, data: Dict[str, Any]) -> AardbeiPeopleDiff:
654
+        return cls(**data)
655
+
656
+    @classmethod
657
+    def get_diff(
658
+        cls, token: str, endpoint: str
659
+    ) -> Union[AardbeiPeopleDiff, NetworkError]:
660
+        try:
661
+            req = requests.post(
662
+                urljoin(SERVER_URL, "/aardbei/diff_people"),
663
+                json={"endpoint": endpoint, "token": token},
664
+            )
665
+            req.raise_for_status()
666
+            data = req.json()
667
+
668
+            return cls.from_dict(data)
669
+
670
+        except requests.ConnectionError as e:
671
+            LOG.exception(e)
672
+            return NetworkError.ConnectionFailure
673
+
674
+        except requests.HTTPError as e:
675
+            LOG.exception(e)
676
+            return NetworkError.HttpFailure
677
+
678
+        except ValueError as e:
679
+            LOG.exception(e)
680
+            return NetworkError.InvalidData
681
+
682
+    @classmethod
683
+    def sync(cls, token: str, endpoint: str) -> Union[AardbeiPeopleDiff, NetworkError]:
684
+        try:
685
+            req = requests.post(
686
+                urljoin(SERVER_URL, "/aardbei/sync_people"),
687
+                json={"endpoint": endpoint, "token": token},
688
+            )
689
+            req.raise_for_status()
690
+            data = req.json()
691
+
692
+            return cls.from_dict(data)
693
+
694
+        except requests.ConnectionError as e:
695
+            LOG.exception(e)
696
+            return NetworkError.ConnectionFailure
697
+
698
+        except requests.HTTPError as e:
699
+            LOG.exception(e)
700
+            return NetworkError.HttpFailure
701
+
702
+        except ValueError as e:
703
+            LOG.exception(e)
704
+            return NetworkError.InvalidData

+ 13 - 14
piket_client/set_active.py

@@ -2,9 +2,13 @@
2 2
 Provides a helper tool to (de-)activate multiple people at once.
3 3
 """
4 4
 
5
+import math
5 6
 import sys
6 7
 
7 8
 # pylint: disable=E0611
9
+import qdarkstyle
10
+from PySide2.QtCore import QObject, QSize, Qt, Signal, Slot
11
+from PySide2.QtGui import QIcon
8 12
 from PySide2.QtWidgets import (
9 13
     QAction,
10 14
     QActionGroup,
@@ -19,14 +23,10 @@ from PySide2.QtWidgets import (
19 23
     QToolBar,
20 24
     QWidget,
21 25
 )
22
-from PySide2.QtGui import QIcon
23
-from PySide2.QtCore import QObject, QSize, Qt, Signal, Slot
24 26
 
25
-# pylint: enable=E0611
27
+from piket_client.model import NetworkError, Person, ServerStatus
26 28
 
27
-import qdarkstyle
28
-
29
-from piket_client.model import Person, ServerStatus
29
+# pylint: enable=E0611
30 30
 
31 31
 
32 32
 class ActivationButton(QPushButton):
@@ -55,7 +55,8 @@ class ActivationButtons(QWidget):
55 55
 
56 56
     def init_ui(self) -> None:
57 57
         ps = Person.get_all()
58
-        num_columns = round(len(ps) / 10) + 1
58
+        assert not isinstance(ps, NetworkError)
59
+        num_columns = math.ceil(math.sqrt(len(ps)))
59 60
 
60 61
         for index, person in enumerate(ps):
61 62
             button = ActivationButton(person, self)
@@ -66,7 +67,7 @@ class ActiveStateMainWindow(QMainWindow):
66 67
     def __init__(self) -> None:
67 68
         super().__init__()
68 69
 
69
-        self.toolbar = None
70
+        self.toolbar = QToolBar()
70 71
 
71 72
         self.init_ui()
72 73
 
@@ -79,7 +80,6 @@ class ActiveStateMainWindow(QMainWindow):
79 80
         icon_size = font_metrics.height() * 1.45
80 81
 
81 82
         # Toolbar
82
-        self.toolbar = QToolBar()
83 83
         self.toolbar.setToolButtonStyle(Qt.ToolButtonTextUnderIcon)
84 84
         self.toolbar.setIconSize(QSize(icon_size, icon_size))
85 85
 
@@ -112,17 +112,16 @@ def main() -> None:
112 112
     app.setFont(font)
113 113
 
114 114
     # Test connectivity
115
-    server_running, info = ServerStatus.is_server_running()
115
+    server_running = ServerStatus.is_server_running()
116 116
 
117
-    if not server_running:
118
-        LOG.critical("Could not connect to server", extra={"info": info})
117
+    if not isinstance(server_running, bool):
119 118
         QMessageBox.critical(
120 119
             None,
121 120
             "Help er is iets kapot",
122 121
             "Kan niet starten omdat de server niet reageert, stuur een foto van "
123
-            "dit naar Maarten: " + repr(info),
122
+            "dit naar Maarten: " + repr(server_running.value),
124 123
         )
125
-        return 1
124
+        return
126 125
 
127 126
     # Load main window
128 127
     main_window = ActiveStateMainWindow()

+ 6 - 8
piket_client/sound.py

@@ -2,16 +2,14 @@
2 2
 Provides functions related to playing sounds.
3 3
 """
4 4
 
5
-import os
5
+import pathlib
6 6
 
7
-import simpleaudio as sa
8 7
 
9
-
10
-SOUNDS_DIR = os.path.join(os.path.dirname(__file__), "sounds")
8
+SOUND_PATH = pathlib.Path(__file__).parent / "sounds"
11 9
 """ Contains the absolute path to the sounds directory. """
12 10
 
13
-PLOP_WAVE = sa.WaveObject.from_wave_file(os.path.join(SOUNDS_DIR, "plop.wav"))
14
-""" SimpleAudio WaveObject containing the plop sound. """
11
+PLOP_PATH = SOUND_PATH / "plop.wav"
12
+""" Path to the "plop" sound. """
15 13
 
16
-UNDO_WAVE = sa.WaveObject.from_wave_file(os.path.join(SOUNDS_DIR, "undo.wav"))
17
-""" SimpleAudio WaveObject containing the undo sound. """
14
+UNDO_PATH = SOUND_PATH / "undo.wav"
15
+""" Path to the "undo" sound". """

+ 9 - 487
piket_server/__init__.py

@@ -2,490 +2,12 @@
2 2
 Piket server, handles events generated by the client.
3 3
 """
4 4
 
5
-import datetime
6
-import os
7
-
8
-from sqlalchemy.exc import SQLAlchemyError
9
-from sqlalchemy import func
10
-from flask import Flask, jsonify, abort, request
11
-from flask_sqlalchemy import SQLAlchemy
12
-
13
-
14
-DATA_HOME = os.environ.get("XDG_DATA_HOME", "~/.local/share")
15
-CONFIG_DIR = os.path.join(DATA_HOME, "piket_server")
16
-DB_PATH = os.path.expanduser(os.path.join(CONFIG_DIR, "database.sqlite3"))
17
-DB_URL = f"sqlite:///{DB_PATH}"
18
-
19
-app = Flask("piket_server")
20
-app.config["SQLALCHEMY_DATABASE_URI"] = DB_URL
21
-app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
22
-db = SQLAlchemy(app)
23
-
24
-
25
-# ---------- Models ----------
26
-class Person(db.Model):
27
-    """ Represents a person to be shown on the lists. """
28
-
29
-    __tablename__ = "people"
30
-
31
-    person_id = db.Column(db.Integer, primary_key=True)
32
-    name = db.Column(db.String, nullable=False)
33
-    active = db.Column(db.Boolean, nullable=False, default=False)
34
-
35
-    consumptions = db.relationship("Consumption", backref="person", lazy=True)
36
-
37
-    def __repr__(self) -> str:
38
-        return f"<Person {self.person_id}: {self.name}>"
39
-
40
-    @property
41
-    def as_dict(self) -> dict:
42
-        return {
43
-            "person_id": self.person_id,
44
-            "active": self.active,
45
-            "name": self.name,
46
-            "consumptions": {
47
-                ct.consumption_type_id: Consumption.query.filter_by(person=self)
48
-                .filter_by(settlement=None)
49
-                .filter_by(consumption_type=ct)
50
-                .filter_by(reversed=False)
51
-                .count()
52
-                for ct in ConsumptionType.query.all()
53
-            },
54
-        }
55
-
56
-
57
-class Export(db.Model):
58
-    """ Represents a set of exported Settlements. """
59
-
60
-    __tablename__ = "exports"
61
-
62
-    export_id = db.Column(db.Integer, primary_key=True)
63
-    created_at = db.Column(
64
-        db.DateTime, default=datetime.datetime.utcnow, nullable=False
65
-    )
66
-
67
-    settlements = db.relationship("Settlement", backref="export", lazy=True)
68
-
69
-    @property
70
-    def as_dict(self) -> dict:
71
-        return {
72
-            "export_id": self.export_id,
73
-            "created_at": self.created_at.isoformat(),
74
-            "settlement_ids": [s.settlement_id for s in self.settlements],
75
-        }
76
-
77
-
78
-class Settlement(db.Model):
79
-    """ Represents a settlement of the list. """
80
-
81
-    __tablename__ = "settlements"
82
-
83
-    settlement_id = db.Column(db.Integer, primary_key=True)
84
-    name = db.Column(db.String, nullable=False)
85
-    export_id = db.Column(db.Integer, db.ForeignKey("exports.export_id"), nullable=True)
86
-
87
-    consumptions = db.relationship("Consumption", backref="settlement", lazy=True)
88
-
89
-    def __repr__(self) -> str:
90
-        return f"<Settlement {self.settlement_id}: {self.name}>"
91
-
92
-    @property
93
-    def as_dict(self) -> dict:
94
-        return {
95
-            "settlement_id": self.settlement_id,
96
-            "name": self.name,
97
-            "consumption_summary": self.consumption_summary,
98
-            "unique_people": self.unique_people,
99
-        }
100
-
101
-    @property
102
-    def unique_people(self) -> int:
103
-        q = (
104
-            Consumption.query.filter_by(settlement=self)
105
-            .filter_by(reversed=False)
106
-            .group_by(Consumption.person_id)
107
-            .count()
108
-        )
109
-        return q
110
-
111
-    @property
112
-    def consumption_summary(self) -> dict:
113
-        q = (
114
-            Consumption.query.filter_by(settlement=self)
115
-            .filter_by(reversed=False)
116
-            .group_by(Consumption.consumption_type_id)
117
-            .order_by(ConsumptionType.name)
118
-            .outerjoin(ConsumptionType)
119
-            .with_entities(
120
-                Consumption.consumption_type_id,
121
-                ConsumptionType.name,
122
-                func.count(Consumption.consumption_id),
123
-            )
124
-            .all()
125
-        )
126
-
127
-        return {r[0]: {"name": r[1], "count": r[2]} for r in q}
128
-
129
-    @property
130
-    def per_person(self) -> dict:
131
-        # Get keys of seen consumption_types
132
-        c_types = self.consumption_summary.keys()
133
-
134
-        result = {}
135
-        for type in c_types:
136
-            c_type = ConsumptionType.query.get(type)
137
-            result[type] = {"consumption_type": c_type.as_dict, "counts": {}}
138
-
139
-            q = (
140
-                Consumption.query.filter_by(settlement=self)
141
-                .filter_by(reversed=False)
142
-                .filter_by(consumption_type=c_type)
143
-                .group_by(Consumption.person_id)
144
-                .order_by(Person.name)
145
-                .outerjoin(Person)
146
-                .with_entities(
147
-                    Person.person_id,
148
-                    Person.name,
149
-                    func.count(Consumption.consumption_id),
150
-                )
151
-                .all()
152
-            )
153
-
154
-            for row in q:
155
-                result[type]["counts"][row[0]] = {"name": row[1], "count": row[2]}
156
-
157
-        return result
158
-
159
-
160
-class ConsumptionType(db.Model):
161
-    """ Represents a type of consumption to be counted. """
162
-
163
-    __tablename__ = "consumption_types"
164
-
165
-    consumption_type_id = db.Column(db.Integer, primary_key=True)
166
-    name = db.Column(db.String, nullable=False)
167
-    icon = db.Column(db.String)
168
-    active = db.Column(db.Boolean, default=True)
169
-
170
-    consumptions = db.relationship("Consumption", backref="consumption_type", lazy=True)
171
-
172
-    def __repr__(self) -> str:
173
-        return f"<ConsumptionType: {self.name}>"
174
-
175
-    @property
176
-    def as_dict(self) -> dict:
177
-        return {
178
-            "consumption_type_id": self.consumption_type_id,
179
-            "name": self.name,
180
-            "icon": self.icon,
181
-        }
182
-
183
-
184
-class Consumption(db.Model):
185
-    """ Represent one consumption to be counted. """
186
-
187
-    __tablename__ = "consumptions"
188
-
189
-    consumption_id = db.Column(db.Integer, primary_key=True)
190
-    person_id = db.Column(db.Integer, db.ForeignKey("people.person_id"), nullable=True)
191
-    consumption_type_id = db.Column(
192
-        db.Integer,
193
-        db.ForeignKey("consumption_types.consumption_type_id"),
194
-        nullable=False,
195
-    )
196
-    settlement_id = db.Column(
197
-        db.Integer, db.ForeignKey("settlements.settlement_id"), nullable=True
198
-    )
199
-    created_at = db.Column(
200
-        db.DateTime, default=datetime.datetime.utcnow, nullable=False
201
-    )
202
-    reversed = db.Column(db.Boolean, default=False, nullable=False)
203
-
204
-    def __repr__(self) -> str:
205
-        return f"<Consumption: {self.consumption_type.name} for {self.person.name}>"
206
-
207
-    @property
208
-    def as_dict(self) -> dict:
209
-        return {
210
-            "consumption_id": self.consumption_id,
211
-            "person_id": self.person_id,
212
-            "consumption_type_id": self.consumption_type_id,
213
-            "settlement_id": self.settlement_id,
214
-            "created_at": self.created_at.isoformat(),
215
-            "reversed": self.reversed,
216
-        }
217
-
218
-
219
-# ---------- Models ----------
220
-
221
-
222
-@app.route("/ping")
223
-def ping() -> None:
224
-    """ Return a status ping. """
225
-    return "Pong"
226
-
227
-
228
-@app.route("/status")
229
-def status() -> None:
230
-    """ Return a status dict with info about the database. """
231
-    unsettled_q = Consumption.query.filter_by(settlement=None).filter_by(reversed=False)
232
-
233
-    unsettled = unsettled_q.count()
234
-
235
-    first = None
236
-    last = None
237
-    if unsettled:
238
-        last = (
239
-            unsettled_q.order_by(Consumption.created_at.desc())
240
-            .first()
241
-            .created_at.isoformat()
242
-        )
243
-        first = (
244
-            unsettled_q.order_by(Consumption.created_at.asc())
245
-            .first()
246
-            .created_at.isoformat()
247
-        )
248
-
249
-    return jsonify({"unsettled": {"amount": unsettled, "first": first, "last": last}})
250
-
251
-
252
-# Person
253
-@app.route("/people", methods=["GET"])
254
-def get_people():
255
-    """ Return a list of currently known people. """
256
-    people = Person.query.order_by(Person.name).all()
257
-    q = Person.query.order_by(Person.name)
258
-    if request.args.get("active"):
259
-        active_status = request.args.get("active", type=int)
260
-        q = q.filter_by(active=active_status)
261
-    people = q.all()
262
-    result = [person.as_dict for person in people]
263
-    return jsonify(people=result)
264
-
265
-
266
-@app.route("/people/<int:person_id>", methods=["GET"])
267
-def get_person(person_id: int):
268
-    person = Person.query.get_or_404(person_id)
269
-
270
-    return jsonify(person=person.as_dict)
271
-
272
-
273
-@app.route("/people", methods=["POST"])
274
-def add_person():
275
-    """
276
-    Add a new person.
277
-
278
-    Required parameters:
279
-    - name (str)
280
-    """
281
-    json = request.get_json()
282
-
283
-    if not json:
284
-        return jsonify({"error": "Could not parse JSON."}), 400
285
-
286
-    data = json.get("person") or {}
287
-    person = Person(name=data.get("name"), active=data.get("active", False))
288
-
289
-    try:
290
-        db.session.add(person)
291
-        db.session.commit()
292
-    except SQLAlchemyError:
293
-        return jsonify({"error": "Invalid arguments for Person."}), 400
294
-
295
-    return jsonify(person=person.as_dict), 201
296
-
297
-
298
-@app.route("/people/<int:person_id>/add_consumption", methods=["POST"])
299
-def add_consumption(person_id: int):
300
-    person = Person.query.get_or_404(person_id)
301
-
302
-    consumption = Consumption(person=person, consumption_type_id=1)
303
-    try:
304
-        db.session.add(consumption)
305
-        db.session.commit()
306
-    except SQLAlchemyError:
307
-        return (
308
-            jsonify(
309
-                {"error": "Invalid Consumption parameters.", "person": person.as_dict}
310
-            ),
311
-            400,
312
-        )
313
-
314
-    return jsonify(person=person.as_dict, consumption=consumption.as_dict), 201
315
-
316
-
317
-@app.route("/people/<int:person_id>", methods=["PATCH"])
318
-def update_person(person_id: int):
319
-    person = Person.query.get_or_404(person_id)
320
-
321
-    data = request.json["person"]
322
-
323
-    if "active" in data:
324
-        person.active = data["active"]
325
-
326
-        db.session.add(person)
327
-        db.session.commit()
328
-
329
-        return jsonify(person=person.as_dict)
330
-
331
-
332
-@app.route("/people/<int:person_id>/add_consumption/<int:ct_id>", methods=["POST"])
333
-def add_consumption2(person_id: int, ct_id: int):
334
-    person = Person.query.get_or_404(person_id)
335
-
336
-    consumption = Consumption(person=person, consumption_type_id=ct_id)
337
-    try:
338
-        db.session.add(consumption)
339
-        db.session.commit()
340
-    except SQLAlchemyError:
341
-        return (
342
-            jsonify(
343
-                {"error": "Invalid Consumption parameters.", "person": person.as_dict}
344
-            ),
345
-            400,
346
-        )
347
-
348
-    return jsonify(person=person.as_dict, consumption=consumption.as_dict), 201
349
-
350
-
351
-@app.route("/consumptions/<int:consumption_id>", methods=["DELETE"])
352
-def reverse_consumption(consumption_id: int):
353
-    """ Reverse a consumption. """
354
-    consumption = Consumption.query.get_or_404(consumption_id)
355
-
356
-    if consumption.reversed:
357
-        return (
358
-            jsonify(
359
-                {
360
-                    "error": "Consumption already reversed",
361
-                    "consumption": consumption.as_dict,
362
-                }
363
-            ),
364
-            409,
365
-        )
366
-
367
-    try:
368
-        consumption.reversed = True
369
-        db.session.add(consumption)
370
-        db.session.commit()
371
-
372
-    except SQLAlchemyError:
373
-        return jsonify({"error": "Database error."}), 500
374
-
375
-    return jsonify(consumption=consumption.as_dict), 200
376
-
377
-
378
-# ConsumptionType
379
-@app.route("/consumption_types", methods=["GET"])
380
-def get_consumption_types():
381
-    """ Return a list of currently active consumption types. """
382
-    ctypes = ConsumptionType.query.filter_by(active=True).all()
383
-    result = [ct.as_dict for ct in ctypes]
384
-    return jsonify(consumption_types=result)
385
-
386
-
387
-@app.route("/consumption_types/<int:consumption_type_id>", methods=["GET"])
388
-def get_consumption_type(consumption_type_id: int):
389
-    ct = ConsumptionType.query.get_or_404(consumption_type_id)
390
-
391
-    return jsonify(consumption_type=ct.as_dict)
392
-
393
-
394
-@app.route("/consumption_types", methods=["POST"])
395
-def add_consumption_type():
396
-    """ Add a new ConsumptionType.  """
397
-    json = request.get_json()
398
-
399
-    if not json:
400
-        return jsonify({"error": "Could not parse JSON."}), 400
401
-
402
-    data = json.get("consumption_type") or {}
403
-    ct = ConsumptionType(name=data.get("name"), icon=data.get("icon"))
404
-
405
-    try:
406
-        db.session.add(ct)
407
-        db.session.commit()
408
-    except SQLAlchemyError:
409
-        return jsonify({"error": "Invalid arguments for ConsumptionType."}), 400
410
-
411
-    return jsonify(consumption_type=ct.as_dict), 201
412
-
413
-
414
-# Settlement
415
-@app.route("/settlements", methods=["GET"])
416
-def get_settlements():
417
-    """ Return a list of the active Settlements. """
418
-    result = Settlement.query.all()
419
-    return jsonify(settlements=[s.as_dict for s in result])
420
-
421
-
422
-@app.route("/settlements/<int:settlement_id>", methods=["GET"])
423
-def get_settlement(settlement_id: int):
424
-    """ Show full details for a single Settlement. """
425
-    s = Settlement.query.get_or_404(settlement_id)
426
-
427
-    per_person = s.per_person
428
-
429
-    return jsonify(settlement=s.as_dict, count_info=per_person)
430
-
431
-
432
-@app.route("/settlements", methods=["POST"])
433
-def add_settlement():
434
-    """ Create a Settlement, and link all un-settled Consumptions to it. """
435
-    json = request.get_json()
436
-
437
-    if not json:
438
-        return jsonify({"error": "Could not parse JSON."}), 400
439
-
440
-    data = json.get("settlement") or {}
441
-    s = Settlement(name=data["name"])
442
-
443
-    db.session.add(s)
444
-    db.session.commit()
445
-
446
-    Consumption.query.filter_by(settlement=None).update(
447
-        {"settlement_id": s.settlement_id}
448
-    )
449
-
450
-    db.session.commit()
451
-
452
-    return jsonify(settlement=s.as_dict)
453
-
454
-
455
-# Export
456
-@app.route("/exports", methods=["GET"])
457
-def get_exports():
458
-    """ Return a list of the created Exports. """
459
-    result = Export.query.all()
460
-    return jsonify(exports=[e.as_dict for e in result])
461
-
462
-
463
-@app.route("/exports/<int:export_id>", methods=["GET"])
464
-def get_export(export_id: int):
465
-    """ Return an overview for the given Export. """
466
-    e = Export.query.get_or_404(export_id)
467
-
468
-    ss = [s.as_dict for s in e.settlements]
469
-
470
-    return jsonify(export=e.as_dict, settlements=ss)
471
-
472
-
473
-@app.route("/exports", methods=["POST"])
474
-def add_export():
475
-    """ Create an Export, and link all un-exported Settlements to it. """
476
-    # Assert that there are Settlements to be exported.
477
-    s_count = Settlement.query.filter_by(export=None).count()
478
-    if s_count == 0:
479
-        return jsonify(error="No un-exported Settlements."), 403
480
-
481
-    e = Export()
482
-
483
-    db.session.add(e)
484
-    db.session.commit()
485
-
486
-    Settlement.query.filter_by(export=None).update({"export_id": e.export_id})
487
-    db.session.commit()
488
-
489
-    ss = [s.as_dict for s in e.settlements]
490
-
491
-    return jsonify(export=e.as_dict, settlements=ss), 201
5
+from piket_server.flask import app
6
+
7
+import piket_server.routes.general
8
+import piket_server.routes.people
9
+import piket_server.routes.consumptions
10
+import piket_server.routes.consumption_types
11
+import piket_server.routes.settlements
12
+import piket_server.routes.exports
13
+import piket_server.routes.aardbei

+ 679 - 0
piket_server/aardbei_sync.py

@@ -0,0 +1,679 @@
1
+from __future__ import annotations
2
+
3
+import datetime
4
+import json
5
+import logging
6
+import sys
7
+from dataclasses import asdict, dataclass
8
+from enum import Enum
9
+from typing import Any, Dict, List, NewType, Optional, Tuple, Union
10
+
11
+import requests
12
+
13
+from piket_server.flask import db
14
+from piket_server.models import Person
15
+from piket_server.util import fmt_datetime
16
+
17
+# AARDBEI_ENDPOINT = "https://aardbei.app"
18
+AARDBEI_ENDPOINT = "http://localhost:3000"
19
+log = logging.getLogger(__name__)
20
+
21
+ActivityId = NewType("ActivityId", int)
22
+PersonId = NewType("PersonId", int)
23
+MemberId = NewType("MemberId", int)
24
+ParticipantId = NewType("ParticipantId", int)
25
+
26
+
27
+@dataclass(frozen=True)
28
+class AardbeiPerson:
29
+    """
30
+    Contains the data on a Person as exposed by Aardbei.
31
+
32
+    A Person represents a person in the real world, and maps to a Person in the local database.
33
+    """
34
+
35
+    aardbei_id: PersonId
36
+    full_name: str
37
+
38
+    @classmethod
39
+    def from_aardbei_dict(cls, data: Dict[str, Any]) -> AardbeiPerson:
40
+        """
41
+        Load from a dictionary provided by Aardbei.
42
+
43
+        >>> AardbeiPerson.from_aardbei_dict(
44
+          {"person": {"aardbei_id": 1, "full_name": "Henkie Kraggelwenk"}}
45
+        )
46
+        AardbeiPerson(aardbei_id=AardbeiId(1), full_name="Henkie Kraggelwenk")
47
+        """
48
+
49
+        d = data["person"]
50
+        return cls(full_name=d["full_name"], aardbei_id=PersonId(d["id"]))
51
+
52
+    @property
53
+    def as_json_dict(self) -> Dict[str, Any]:
54
+        """
55
+        Serialize to a dictionary as provided by Aardbei.
56
+
57
+        >>> AardbeiPerson(aardbei_id=AardbeiId(1), full_name="Henkie Kraggelwenk").as_json_dict
58
+        {"person": {"id": 1, "full_name": "Henkie Kraggelwenk"}}
59
+        """
60
+
61
+        return {"person": {"id": self.aardbei_id, "full_name": self.full_name}}
62
+
63
+
64
+@dataclass(frozen=True)
65
+class AardbeiMember:
66
+    """
67
+    Contains the data on a Member exposed by Aardbei.
68
+
69
+    A Member represents the membership of a Person in a Group in Aardbei.
70
+    """
71
+
72
+    person: AardbeiPerson
73
+    aardbei_id: MemberId
74
+    is_leader: bool
75
+    display_name: str
76
+
77
+    @classmethod
78
+    def from_aardbei_dict(cls, data: Dict[str, Any]) -> AardbeiMember:
79
+        """
80
+        Load from a dictionary provided by Aardbei.
81
+
82
+        >>> from_aardbei_dict({
83
+            "member": {
84
+                "person": {
85
+                    "full_name": "Roer Kuggelvork",
86
+                    "id": 2,
87
+                },
88
+                "id": 23,
89
+                "is_leader": False,
90
+                "display_name": "Roer",
91
+            },
92
+        })
93
+        AardbeiMember(
94
+            person=AardbeiPerson(aardbei_id=PersonId(2), full_name="Roer Kuggelvork"),
95
+            aardbei_id=MemberId(23),
96
+            is_leader=False,
97
+            display_name="Roer",
98
+        )
99
+        """
100
+
101
+        d = data["member"]
102
+        person = AardbeiPerson.from_aardbei_dict(d)
103
+        return cls(
104
+            person=person,
105
+            aardbei_id=MemberId(d["id"]),
106
+            is_leader=d["is_leader"],
107
+            display_name=d["display_name"],
108
+        )
109
+
110
+    @property
111
+    def as_json_dict(self) -> Dict[str, Any]:
112
+        """
113
+        Serialize to a dict as provided by Aardbei.
114
+
115
+        >>> AardbeiMember(
116
+            person=AardbeiPerson(aardbei_id=PersonId(2), full_name="Roer Kuggelvork"),
117
+            aardbei_id=MemberId(23),
118
+            is_leader=False,
119
+            display_name="Roer",
120
+        )
121
+        {
122
+            "member": {
123
+                "person": {
124
+                    "full_name": "Roer Kuggelvork",
125
+                    "id": 2,
126
+                },
127
+                "id": 23,
128
+                "is_leader": False,
129
+                "display_name": "Roer",
130
+            }
131
+        }
132
+        """
133
+        res = {
134
+            "id": self.aardbei_id,
135
+            "is_leader": self.is_leader,
136
+            "display_name": self.display_name,
137
+        }
138
+        res.update(self.person.as_json_dict)
139
+        return res
140
+
141
+
142
+@dataclass(frozen=True)
143
+class AardbeiParticipant:
144
+    """
145
+    Represents a Participant as exposed by Aardbei.
146
+
147
+    A Participant represents the participation of a Person (optionally as a Member in a Group) in an Activity.
148
+    """
149
+
150
+    person: AardbeiPerson
151
+    member: Optional[AardbeiMember]
152
+    aardbei_id: ParticipantId
153
+    attending: bool
154
+    is_organizer: bool
155
+    notes: Optional[str]
156
+
157
+    @property
158
+    def name(self) -> str:
159
+        """
160
+        Return the name to show for this Participant.
161
+        This is the display_name if a Member is present, else the Participant's Person's full name.
162
+        """
163
+        if self.member is not None:
164
+            return self.member.display_name
165
+
166
+        return self.person.full_name
167
+
168
+    @classmethod
169
+    def from_aardbei_dict(cls, data: Dict[str, Any]) -> AardbeiParticipant:
170
+        """
171
+        Load from a dictionary as provided by Aardbei.
172
+        """
173
+        d = data["participant"]
174
+        person = AardbeiPerson.from_aardbei_dict(d)
175
+
176
+        member: Optional[AardbeiMember] = None
177
+        if d["member"] is not None:
178
+            member = AardbeiMember.from_aardbei_dict(d)
179
+
180
+        aardbei_id = ParticipantId(d["id"])
181
+
182
+        return cls(
183
+            person=person,
184
+            member=member,
185
+            aardbei_id=aardbei_id,
186
+            attending=d["attending"],
187
+            is_organizer=d["is_organizer"],
188
+            notes=d["notes"],
189
+        )
190
+
191
+    @property
192
+    def as_json_dict(self) -> Dict[str, Any]:
193
+        """
194
+        Serialize to a dict as provided by Aardbei.
195
+        """
196
+        res = {
197
+            "participant": {
198
+                "id": self.aardbei_id,
199
+                "attending": self.attending,
200
+                "is_organizer": self.is_organizer,
201
+                "notes": self.notes,
202
+            }
203
+        }
204
+        res.update(self.person.as_json_dict)
205
+        if self.member is not None:
206
+            res.update(self.member.as_json_dict)
207
+
208
+        return res
209
+
210
+
211
+class NoResponseAction(Enum):
212
+    """Represents the "no response action" attribute of Activities in Aardbei."""
213
+
214
+    Present = "present"
215
+    Absent = "absent"
216
+
217
+
218
+@dataclass(frozen=True)
219
+class ResponseCounts:
220
+    """Represents the "response counts" attribute of Activities in Aardbei."""
221
+
222
+    present: int
223
+    absent: int
224
+    unknown: int
225
+
226
+    @classmethod
227
+    def from_aardbei_dict(cls, data: Dict[str, int]) -> ResponseCounts:
228
+        """Load from a dict as provided by Aardbei."""
229
+        return cls(
230
+            present=data["present"], absent=data["absent"], unknown=data["unknown"]
231
+        )
232
+
233
+    @property
234
+    def as_json_dict(self) -> Dict[str, int]:
235
+        """Serialize to a dict as provided by Aardbei."""
236
+        return {"present": self.present, "absent": self.absent, "unknown": self.unknown}
237
+
238
+
239
+@dataclass(frozen=True)
240
+class SparseAardbeiActivity:
241
+    aardbei_id: ActivityId
242
+    name: str
243
+    description: str
244
+    location: str
245
+    start: datetime.datetime
246
+    end: Optional[datetime.datetime]
247
+    deadline: Optional[datetime.datetime]
248
+    reminder_at: Optional[datetime.datetime]
249
+    no_response_action: NoResponseAction
250
+    response_counts: ResponseCounts
251
+
252
+    def distance(self, reference: datetime.datetime) -> datetime.timedelta:
253
+        """Calculate how long ago this Activity ended / how much time until it starts."""
254
+        if self.end is not None:
255
+            if reference > self.start and reference < self.end:
256
+                return datetime.timedelta(seconds=0)
257
+
258
+            elif reference < self.start:
259
+                return self.start - reference
260
+
261
+            elif reference > self.end:
262
+                return reference - self.end
263
+
264
+        if reference > self.start:
265
+            return reference - self.start
266
+
267
+        return self.start - reference
268
+
269
+    @classmethod
270
+    def from_aardbei_dict(cls, data: Dict[str, Any]) -> SparseAardbeiActivity:
271
+        """Load from a dict as provided by Aardbei."""
272
+        start: datetime.datetime = datetime.datetime.fromisoformat(
273
+            data["activity"]["start"]
274
+        )
275
+        end: Optional[datetime.datetime] = None
276
+
277
+        if data["activity"]["end"] is not None:
278
+            end = datetime.datetime.fromisoformat(data["activity"]["end"])
279
+
280
+        deadline: Optional[datetime.datetime] = None
281
+        if data["activity"]["deadline"] is not None:
282
+            deadline = datetime.datetime.fromisoformat(data["activity"]["deadline"])
283
+
284
+        reminder_at: Optional[datetime.datetime] = None
285
+        if data["activity"]["reminder_at"] is not None:
286
+            reminder_at = datetime.datetime.fromisoformat(
287
+                data["activity"]["reminder_at"]
288
+            )
289
+
290
+        no_response_action = NoResponseAction(data["activity"]["no_response_action"])
291
+
292
+        response_counts = ResponseCounts.from_aardbei_dict(
293
+            data["activity"]["response_counts"]
294
+        )
295
+
296
+        return cls(
297
+            aardbei_id=ActivityId(data["activity"]["id"]),
298
+            name=data["activity"]["name"],
299
+            description=data["activity"]["description"],
300
+            location=data["activity"]["location"],
301
+            start=start,
302
+            end=end,
303
+            deadline=deadline,
304
+            reminder_at=reminder_at,
305
+            no_response_action=no_response_action,
306
+            response_counts=response_counts,
307
+        )
308
+
309
+    @property
310
+    def as_json_dict(self) -> Dict[str, Any]:
311
+        """Serialize to a dict as provided by Aardbei."""
312
+        return {
313
+            "activity": {
314
+                "id": self.aardbei_id,
315
+                "name": self.name,
316
+                "description": self.description,
317
+                "location": self.location,
318
+                "start": fmt_datetime(self.start),
319
+                "end": fmt_datetime(self.end),
320
+                "deadline": fmt_datetime(self.deadline),
321
+                "reminder_at": fmt_datetime(self.reminder_at),
322
+                "no_response_action": self.no_response_action.value,
323
+                "response_counts": self.response_counts.as_json_dict,
324
+            }
325
+        }
326
+
327
+
328
+@dataclass(frozen=True)
329
+class AardbeiActivity(SparseAardbeiActivity):
330
+    """Contains the data of an Activity as exposed by Aardbei."""
331
+
332
+    participants: List[AardbeiParticipant]
333
+
334
+    @classmethod
335
+    def from_aardbei_dict(cls, data: Dict[str, Any]) -> AardbeiActivity:
336
+        """Load from a dict as provided by Aardbei."""
337
+        # Ugly: This is a copy of the Sparse variant with added participants.
338
+        # This is not ideal, but I don't care enough to fix this right now.
339
+        participants: List[AardbeiParticipant] = [
340
+            AardbeiParticipant.from_aardbei_dict(x)
341
+            for x in data["activity"]["participants"]
342
+        ]
343
+
344
+        start: datetime.datetime = datetime.datetime.fromisoformat(
345
+            data["activity"]["start"]
346
+        )
347
+        end: Optional[datetime.datetime] = None
348
+
349
+        if data["activity"]["end"] is not None:
350
+            end = datetime.datetime.fromisoformat(data["activity"]["end"])
351
+
352
+        deadline: Optional[datetime.datetime] = None
353
+        if data["activity"]["deadline"] is not None:
354
+            deadline = datetime.datetime.fromisoformat(data["activity"]["deadline"])
355
+
356
+        reminder_at: Optional[datetime.datetime] = None
357
+        if data["activity"]["reminder_at"] is not None:
358
+            reminder_at = datetime.datetime.fromisoformat(
359
+                data["activity"]["reminder_at"]
360
+            )
361
+
362
+        no_response_action = NoResponseAction(data["activity"]["no_response_action"])
363
+
364
+        response_counts = ResponseCounts.from_aardbei_dict(
365
+            data["activity"]["response_counts"]
366
+        )
367
+
368
+        return cls(
369
+            aardbei_id=ActivityId(data["activity"]["id"]),
370
+            name=data["activity"]["name"],
371
+            description=data["activity"]["description"],
372
+            location=data["activity"]["location"],
373
+            start=start,
374
+            end=end,
375
+            deadline=deadline,
376
+            reminder_at=reminder_at,
377
+            no_response_action=no_response_action,
378
+            response_counts=response_counts,
379
+            participants=participants,
380
+        )
381
+
382
+    @property
383
+    def as_json_dict(self) -> Dict[str, Any]:
384
+        """Serialize to a dict as provided by Aardbei."""
385
+        res = super().as_json_dict
386
+        res["participants"] = [p.as_json_dict for p in self.participants]
387
+        return res
388
+
389
+
390
+@dataclass(frozen=True)
391
+class AardbeiMatch:
392
+    """Represents a match between a local Person and a Person present in Aardbei's data."""
393
+
394
+    local: Person
395
+    remote: AardbeiMember
396
+
397
+
398
+@dataclass(frozen=True)
399
+class AardbeiLink:
400
+    """Represents a set of differences between the local state and Aardbei's set of people."""
401
+
402
+    matches: List[AardbeiMatch]
403
+    """People that exist on both sides, but aren't linked in the people table."""
404
+    altered_name: List[AardbeiMatch]
405
+    """People that are already linked but changed one of their names."""
406
+    remote_only: List[AardbeiMember]
407
+    """People that only exist on the remote."""
408
+
409
+    @property
410
+    def num_changes(self) -> int:
411
+        """Return the amount of mismatching people between Aardbei and the local state."""
412
+        return len(self.matches) + len(self.altered_name) + len(self.remote_only)
413
+
414
+
415
+class AardbeiSyncError(Enum):
416
+    """Represents errors that might occur when retrieving data from Aardbei."""
417
+
418
+    CantConnect = "connect_fail"
419
+    HTTPError = "http_fail"
420
+
421
+
422
+def get_aardbei_people(
423
+    token: str, endpoint: str = AARDBEI_ENDPOINT
424
+) -> Union[List[AardbeiMember], AardbeiSyncError]:
425
+    """Retrieve the set of People in a Group from Aardbei, and parse this to
426
+    AardbeiPerson objects. Return a AardbeiSyncError if something fails."""
427
+    try:
428
+        resp: requests.Response = requests.get(
429
+            f"{endpoint}/api/groups/0/",
430
+            headers={"Authorization": f"Group {token}"},
431
+        )
432
+        resp.raise_for_status()
433
+
434
+    except requests.ConnectionError as e:
435
+        log.exception("Can't connect to endpoint %s", endpoint)
436
+        return AardbeiSyncError.CantConnect
437
+
438
+    except requests.HTTPError:
439
+        return AardbeiSyncError.HTTPError
440
+
441
+    members = resp.json()["group"]["members"]
442
+
443
+    return [AardbeiMember.from_aardbei_dict(x) for x in members]
444
+
445
+
446
+def match_local_aardbei(aardbei_members: List[AardbeiMember]) -> AardbeiLink:
447
+    """Inspect the local state and compare it with the set of given
448
+    AardbeiMembers (containing AardbeiPersons). Return a AardbeiLink that
449
+    indicates which local people don't match the remote state."""
450
+
451
+    matches: List[AardbeiMatch] = []
452
+    altered_name: List[AardbeiMatch] = []
453
+    remote_only: List[AardbeiMember] = []
454
+
455
+    for member in aardbei_members:
456
+        p: Optional[Person] = Person.query.filter_by(
457
+            aardbei_id=member.person.aardbei_id
458
+        ).one_or_none()
459
+
460
+        if p is not None:
461
+            if (
462
+                p.full_name != member.person.full_name
463
+                or p.display_name != member.display_name
464
+            ):
465
+                altered_name.append(AardbeiMatch(p, member))
466
+
467
+            else:
468
+                logging.info(
469
+                    "OK: %s / %s (L%s/R%s)",
470
+                    p.full_name,
471
+                    p.display_name,
472
+                    p.person_id,
473
+                    p.aardbei_id,
474
+                )
475
+
476
+            continue
477
+
478
+        p = Person.query.filter_by(full_name=member.person.full_name).one_or_none()
479
+
480
+        if p is not None:
481
+            matches.append(AardbeiMatch(p, member))
482
+        else:
483
+            remote_only.append(member)
484
+
485
+    return AardbeiLink(matches, altered_name, remote_only)
486
+
487
+
488
+def link_matches(matches: List[AardbeiMatch]) -> None:
489
+    """
490
+    Update local people to add the remote ID to the local state.
491
+    This only enqueues the changes in the local SQLAlchemy session, committing
492
+    needs to be done separately.
493
+    """
494
+
495
+    for match in matches:
496
+        match.local.aardbei_id = match.remote.person.aardbei_id
497
+        match.local.display_name = match.remote.display_name
498
+        logging.info(
499
+            "Linking local %s (%s) to remote %s (%s)",
500
+            match.local.full_name,
501
+            match.local.person_id,
502
+            match.remote.display_name,
503
+            match.remote.person.aardbei_id,
504
+        )
505
+
506
+        db.session.add(match.local)
507
+
508
+
509
+def create_missing(missing: List[AardbeiMember]) -> None:
510
+    """
511
+    Create local people for all remote people that don't exist locally.
512
+    This only enqueues the changes in the local SQLAlchemy session, committing
513
+    needs to be done separately.
514
+    """
515
+
516
+    for member in missing:
517
+        pnew = Person(
518
+            full_name=member.person.full_name,
519
+            display_name=member.display_name,
520
+            aardbei_id=member.person.aardbei_id,
521
+            active=False,
522
+        )
523
+        logging.info(
524
+            "Creating new person for %s / %s (%s)",
525
+            member.person.full_name,
526
+            member.display_name,
527
+            member.person.aardbei_id,
528
+        )
529
+        db.session.add(pnew)
530
+
531
+
532
+def update_names(matches: List[AardbeiMatch]) -> None:
533
+    """
534
+    Update the local full and display names of people that were already linked
535
+    to a remote person, and who changed names on the remote.
536
+
537
+    This only enqueues the changes in the local SQLAlchemy session, committing
538
+    needs to be done separately.
539
+    """
540
+
541
+    for match in matches:
542
+        p = match.local
543
+        member = match.remote
544
+        aardbei_person = member.person
545
+
546
+        changed = False
547
+
548
+        if p.full_name != aardbei_person.full_name:
549
+            logging.info(
550
+                "Updating %s (L%s/R%s) full name %s to %s",
551
+                aardbei_person.full_name,
552
+                p.person_id,
553
+                aardbei_person.aardbei_id,
554
+                p.full_name,
555
+                aardbei_person.full_name,
556
+            )
557
+            p.full_name = aardbei_person.full_name
558
+            changed = True
559
+
560
+        if p.display_name != member.display_name:
561
+            logging.info(
562
+                "Updating %s (L%s/R%s) display name %s to %s",
563
+                p.full_name,
564
+                p.person_id,
565
+                aardbei_person.aardbei_id,
566
+                p.display_name,
567
+                member.display_name,
568
+            )
569
+            p.display_name = member.display_name
570
+            changed = True
571
+
572
+        assert changed, "got match but didn't update anything"
573
+
574
+        db.session.add(p)
575
+
576
+
577
+def get_activities(
578
+    token: str, endpoint: str = AARDBEI_ENDPOINT
579
+) -> Union[List[SparseAardbeiActivity], AardbeiSyncError]:
580
+    """
581
+    Get the list of activities present on the remote and return these
582
+    activities, ordered by the temporal distance to the current time.
583
+    """
584
+
585
+    result: List[SparseAardbeiActivity] = []
586
+
587
+    for category in ("upcoming", "current", "previous"):
588
+        try:
589
+            resp = requests.get(
590
+                f"{endpoint}/api/groups/0/{category}_activities",
591
+                headers={"Authorization": f"Group {token}"},
592
+            )
593
+
594
+            resp.raise_for_status()
595
+
596
+        except requests.HTTPError as e:
597
+            log.exception(e)
598
+            return AardbeiSyncError.HTTPError
599
+
600
+        except requests.ConnectionError as e:
601
+            log.exception(e)
602
+            return AardbeiSyncError.CantConnect
603
+
604
+        for item in resp.json():
605
+            result.append(SparseAardbeiActivity.from_aardbei_dict(item))
606
+
607
+    now = datetime.datetime.now(datetime.timezone.utc)
608
+    result.sort(key=lambda x: SparseAardbeiActivity.distance(x, now))
609
+    return result
610
+
611
+
612
+def get_activity(
613
+    activity_id: ActivityId, token: str, endpoint: str
614
+) -> Union[AardbeiActivity, AardbeiSyncError]:
615
+    """
616
+    Get all data (including participants) from the remote about one activity
617
+    with a given ID.
618
+    """
619
+
620
+    try:
621
+        resp = requests.get(
622
+            f"{endpoint}/api/activities/{activity_id}",
623
+            headers={"Authorization": f"Group {token}"},
624
+        )
625
+
626
+        resp.raise_for_status()
627
+
628
+    except requests.HTTPError as e:
629
+        log.exception(e)
630
+        return AardbeiSyncError.HTTPError
631
+
632
+    except requests.ConnectionError as e:
633
+        return AardbeiSyncError.CantConnect
634
+
635
+    return AardbeiActivity.from_aardbei_dict(resp.json())
636
+
637
+
638
+def match_activity(activity: AardbeiActivity) -> None:
639
+    """
640
+    Update the local state to have mark all people present at the given
641
+    activity as active, and all other people as inactive.
642
+    """
643
+    ps = activity.participants
644
+    pids: List[PersonId] = [p.person.aardbei_id for p in ps if p.attending]
645
+
646
+    Person.query.update(values={"active": False})
647
+    Person.query.filter(Person.aardbei_id.in_(pids)).update(
648
+        values={"active": True}, synchronize_session="fetch"
649
+    )
650
+
651
+
652
+if __name__ == "__main__":
653
+    logging.basicConfig(level=logging.DEBUG)
654
+
655
+    token = input("Token: ")
656
+    aardbei_people = get_aardbei_people(token)
657
+
658
+    if isinstance(aardbei_people, AardbeiSyncError):
659
+        logging.error("Could not get people: %s", aardbei_people.value)
660
+        sys.exit(1)
661
+
662
+    activities = get_activities(token)
663
+
664
+    if isinstance(activities, AardbeiSyncError):
665
+        logging.error("Could not get activities: %s", activities.value)
666
+        sys.exit(1)
667
+
668
+    link = match_local_aardbei(aardbei_people)
669
+
670
+    link_matches(link.matches)
671
+    create_missing(link.remote_only)
672
+    update_names(link.altered_name)
673
+
674
+    confirm = input("Commit? Y/N")
675
+    if confirm.lower() == "y":
676
+        print("Committing.")
677
+        db.session.commit()
678
+    else:
679
+        print("Not committing.")

+ 9 - 11
piket_server/alembic/env.py

@@ -12,26 +12,24 @@ config = context.config
12 12
 # This line sets up loggers basically.
13 13
 fileConfig(config.config_file_name)
14 14
 
15
-# add your model's MetaData object here
16
-# for 'autogenerate' support
17
-# from myapp import mymodel
18
-# target_metadata = mymodel.Base.metadata
19
-import piket_server
20
-
21
-target_metadata = piket_server.db.Model.metadata
22
-
23 15
 # other values from the config, defined by the needs of env.py,
24 16
 # can be acquired:
25 17
 # my_important_option = config.get_main_option("my_important_option")
26 18
 # ... etc.
27
-from piket_server import CONFIG_DIR, DB_URL
19
+from piket_server.flask import CONFIG_DIR, DB_URL, db
20
+
21
+# add your model's MetaData object here
22
+# for 'autogenerate' support
23
+# from myapp import mymodel
24
+# target_metadata = mymodel.Base.metadata
25
+target_metadata = db.Model.metadata
28 26
 
29 27
 os.makedirs(os.path.expanduser(CONFIG_DIR), mode=0o744, exist_ok=True)
30 28
 
31 29
 config.file_config["alembic"]["sqlalchemy.url"] = DB_URL
32 30
 
33 31
 
34
-def run_migrations_offline():
32
+def run_migrations_offline() -> None:
35 33
     """Run migrations in 'offline' mode.
36 34
 
37 35
     This configures the context with just a URL
@@ -50,7 +48,7 @@ def run_migrations_offline():
50 48
         context.run_migrations()
51 49
 
52 50
 
53
-def run_migrations_online():
51
+def run_migrations_online() -> None:
54 52
     """Run migrations in 'online' mode.
55 53
 
56 54
     In this scenario we need to create an Engine

+ 36 - 0
piket_server/alembic/versions/6a5989118ee3_enable_unique_constraints.py

@@ -0,0 +1,36 @@
1
+"""Enable unique constraints
2
+
3
+Revision ID: 6a5989118ee3
4
+Revises: cca57457a0a6
5
+Create Date: 2019-09-22 17:04:01.945713
6
+
7
+"""
8
+from alembic import op
9
+import sqlalchemy as sa
10
+
11
+
12
+# revision identifiers, used by Alembic.
13
+revision = "6a5989118ee3"
14
+down_revision = "cca57457a0a6"
15
+branch_labels = None
16
+depends_on = None
17
+
18
+
19
+def upgrade():
20
+    with op.batch_alter_table("consumption_types") as batch_op:
21
+        batch_op.create_unique_constraint("uc_consumption_types_name", ["name"])
22
+
23
+    with op.batch_alter_table("people") as batch_op2:
24
+        batch_op2.create_unique_constraint("uc_people_aardbei_id", ["aardbei_id"])
25
+        batch_op2.create_unique_constraint("uc_people_full_name", ["full_name"])
26
+        batch_op2.create_unique_constraint("uc_people_display_name", ["display_name"])
27
+
28
+
29
+def downgrade():
30
+    with op.batch_alter_table("people") as batch_op2:
31
+        batch_op2.drop_constraint("uc_people_display_name", type_="unique")
32
+        batch_op2.drop_constraint("uc_people_full_name", type_="unique")
33
+        batch_op2.drop_constraint("uc_people_aardbei_id", type_="unique")
34
+
35
+    with op.batch_alter_table("consumption_types") as batch_op:
36
+        batch_op.drop_constraint("uc_consumption_types_name", type_="unique")

+ 30 - 0
piket_server/alembic/versions/cca57457a0a6_add_aardbei_fields.py

@@ -0,0 +1,30 @@
1
+"""Add Aardbei fields
2
+
3
+Revision ID: cca57457a0a6
4
+Revises: 2f3a49058a67
5
+Create Date: 2019-09-05 21:38:28.489281
6
+
7
+"""
8
+from alembic import op
9
+import sqlalchemy as sa
10
+
11
+
12
+# revision identifiers, used by Alembic.
13
+revision = "cca57457a0a6"
14
+down_revision = "2f3a49058a67"
15
+branch_labels = None
16
+depends_on = None
17
+
18
+
19
+def upgrade():
20
+    with op.batch_alter_table("people") as batch_op:
21
+        batch_op.alter_column("name", new_column_name="full_name")
22
+        batch_op.add_column(sa.Column("aardbei_id", sa.Integer(), nullable=True))
23
+        batch_op.add_column(sa.Column("display_name", sa.String(), nullable=True))
24
+
25
+
26
+def downgrade():
27
+    with op.batch_alter_table("people") as batch_op:
28
+        batch_op.alter_column("full_name", new_column_name="name")
29
+        batch_op.drop_column("aardbei_id")
30
+        batch_op.drop_column("display_name")

+ 19 - 0
piket_server/flask.py

@@ -0,0 +1,19 @@
1
+"""
2
+Defines the Flask object used to run the server.
3
+"""
4
+
5
+import os
6
+from typing import Any
7
+
8
+from flask import Flask
9
+from flask_sqlalchemy import SQLAlchemy  # type: ignore
10
+
11
+DATA_HOME = os.environ.get("XDG_DATA_HOME", "~/.local/share")
12
+CONFIG_DIR = os.path.join(DATA_HOME, "piket_server")
13
+DB_PATH = os.path.expanduser(os.path.join(CONFIG_DIR, "database.sqlite3"))
14
+DB_URL = f"sqlite:///{DB_PATH}"
15
+
16
+app = Flask("piket_server")
17
+app.config["SQLALCHEMY_DATABASE_URI"] = DB_URL
18
+app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
19
+db: Any = SQLAlchemy(app)

+ 250 - 0
piket_server/models.py

@@ -0,0 +1,250 @@
1
+"""
2
+Defines database models used by the server.
3
+"""
4
+
5
+import datetime
6
+from typing import List, Dict, Any
7
+from collections import defaultdict
8
+
9
+from sqlalchemy import func
10
+from sqlalchemy.exc import SQLAlchemyError
11
+
12
+from piket_server.flask import db
13
+
14
+
15
+class Person(db.Model):
16
+    """ Represents a person to be shown on the lists. """
17
+
18
+    __tablename__ = "people"
19
+
20
+    person_id = db.Column(db.Integer, primary_key=True)
21
+    full_name = db.Column(db.String, nullable=False, unique=True)
22
+    display_name = db.Column(db.String, nullable=True, unique=True)
23
+    aardbei_id = db.Column(db.Integer, nullable=True, unique=True)
24
+    active = db.Column(db.Boolean, nullable=False, default=False)
25
+
26
+    consumptions = db.relationship("Consumption", backref="person", lazy=True)
27
+
28
+    def __repr__(self) -> str:
29
+        return f"<Person {self.person_id}: {self.full_name}>"
30
+
31
+    @property
32
+    def as_dict(self) -> dict:
33
+        return {
34
+            "person_id": self.person_id,
35
+            "active": self.active,
36
+            "full_name": self.full_name,
37
+            "display_name": self.display_name,
38
+            "consumptions": {
39
+                ct.consumption_type_id: Consumption.query.filter_by(person=self)
40
+                .filter_by(settlement=None)
41
+                .filter_by(consumption_type=ct)
42
+                .filter_by(reversed=False)
43
+                .count()
44
+                for ct in ConsumptionType.query.all()
45
+            },
46
+        }
47
+
48
+
49
+class Export(db.Model):
50
+    """ Represents a set of exported Settlements. """
51
+
52
+    __tablename__ = "exports"
53
+
54
+    export_id = db.Column(db.Integer, primary_key=True)
55
+    created_at = db.Column(
56
+        db.DateTime, default=datetime.datetime.utcnow, nullable=False
57
+    )
58
+
59
+    settlements = db.relationship("Settlement", backref="export", lazy=True)
60
+
61
+    @property
62
+    def as_dict(self) -> dict:
63
+        return {
64
+            "export_id": self.export_id,
65
+            "created_at": self.created_at.isoformat(),
66
+            "settlement_ids": [s.settlement_id for s in self.settlements],
67
+        }
68
+
69
+
70
+class Settlement(db.Model):
71
+    """ Represents a settlement of the list. """
72
+
73
+    __tablename__ = "settlements"
74
+
75
+    settlement_id = db.Column(db.Integer, primary_key=True)
76
+    name = db.Column(db.String, nullable=False)
77
+    export_id = db.Column(db.Integer, db.ForeignKey("exports.export_id"), nullable=True)
78
+
79
+    consumptions = db.relationship("Consumption", backref="settlement", lazy=True)
80
+
81
+    def __repr__(self) -> str:
82
+        return f"<Settlement {self.settlement_id}: {self.name}>"
83
+
84
+    @property
85
+    def as_dict(self) -> dict:
86
+        return {
87
+            "settlement_id": self.settlement_id,
88
+            "name": self.name,
89
+            "consumption_summary": self.consumption_summary,
90
+            "unique_people": self.unique_people,
91
+            "per_person_counts": self.per_person_counts,
92
+            "count_info": self.per_person,
93
+        }
94
+
95
+    @property
96
+    def unique_people(self) -> int:
97
+        q = (
98
+            Consumption.query.filter_by(settlement=self)
99
+            .filter_by(reversed=False)
100
+            .group_by(Consumption.person_id)
101
+            .count()
102
+        )
103
+        return q
104
+
105
+    @property
106
+    def consumption_summary(self) -> dict:
107
+        q = (
108
+            Consumption.query.filter_by(settlement=self)
109
+            .filter_by(reversed=False)
110
+            .group_by(Consumption.consumption_type_id)
111
+            .order_by(ConsumptionType.name)
112
+            .outerjoin(ConsumptionType)
113
+            .with_entities(
114
+                Consumption.consumption_type_id,
115
+                ConsumptionType.name,
116
+                func.count(Consumption.consumption_id),
117
+            )
118
+            .all()
119
+        )
120
+
121
+        return {r[0]: {"name": r[1], "count": r[2]} for r in q}
122
+
123
+    @property
124
+    def per_person(self) -> dict:
125
+        # Get keys of seen consumption_types
126
+        c_types = self.consumption_summary.keys()
127
+
128
+        result = {}
129
+        for type in c_types:
130
+            c_type = ConsumptionType.query.get(type)
131
+            result[type] = {"consumption_type": c_type.as_dict, "counts": {}}
132
+
133
+            q = (
134
+                Consumption.query.filter_by(settlement=self)
135
+                .filter_by(reversed=False)
136
+                .filter_by(consumption_type=c_type)
137
+                .group_by(Consumption.person_id)
138
+                .order_by(Person.full_name)
139
+                .outerjoin(Person)
140
+                .with_entities(
141
+                    Person.person_id,
142
+                    Person.full_name,
143
+                    func.count(Consumption.consumption_id),
144
+                )
145
+                .all()
146
+            )
147
+
148
+            for row in q:
149
+                result[type]["counts"][row[0]] = {"name": row[1], "count": row[2]}
150
+
151
+        return result
152
+
153
+    @property
154
+    def per_person_counts(self) -> Dict[int, Any]:
155
+        """
156
+        Output a more usable dict containing for each person in the settlement
157
+        how many of each consumption type was counted.
158
+        """
159
+
160
+        q = (
161
+            Consumption.query.filter_by(settlement=self)
162
+            .filter_by(reversed=False)
163
+            .group_by(Consumption.person_id)
164
+            .group_by(Consumption.consumption_type_id)
165
+            .group_by(Person.full_name)
166
+            .outerjoin(Person)
167
+            .with_entities(
168
+                Consumption.person_id,
169
+                Person.full_name,
170
+                Consumption.consumption_type_id,
171
+                func.count(Consumption.consumption_id),
172
+            )
173
+            .all()
174
+        )
175
+
176
+        res: Dict[int, Any] = defaultdict(dict)
177
+
178
+        for row in q:
179
+            item = res[row[0]]
180
+            item["full_name"] = row[1]
181
+            if not item.get("counts"):
182
+                item["counts"] = {}
183
+
184
+            item["counts"][row[2]] = row[3]
185
+
186
+        return res
187
+
188
+
189
+
190
+
191
+class ConsumptionType(db.Model):
192
+    """ Represents a type of consumption to be counted. """
193
+
194
+    __tablename__ = "consumption_types"
195
+
196
+    consumption_type_id = db.Column(db.Integer, primary_key=True)
197
+    name = db.Column(db.String, nullable=False, unique=True)
198
+    icon = db.Column(db.String)
199
+    active = db.Column(db.Boolean, default=True)
200
+
201
+    consumptions = db.relationship("Consumption", backref="consumption_type", lazy=True)
202
+
203
+    def __repr__(self) -> str:
204
+        return f"<ConsumptionType: {self.name}>"
205
+
206
+    @property
207
+    def as_dict(self) -> dict:
208
+        return {
209
+            "consumption_type_id": self.consumption_type_id,
210
+            "name": self.name,
211
+            "icon": self.icon,
212
+            "active": self.active,
213
+        }
214
+
215
+
216
+class Consumption(db.Model):
217
+    """ Represent one consumption to be counted. """
218
+
219
+    __tablename__ = "consumptions"
220
+
221
+    consumption_id = db.Column(db.Integer, primary_key=True)
222
+    person_id = db.Column(db.Integer, db.ForeignKey("people.person_id"), nullable=True)
223
+    consumption_type_id = db.Column(
224
+        db.Integer,
225
+        db.ForeignKey("consumption_types.consumption_type_id"),
226
+        nullable=False,
227
+    )
228
+    settlement_id = db.Column(
229
+        db.Integer, db.ForeignKey("settlements.settlement_id"), nullable=True
230
+    )
231
+    created_at = db.Column(
232
+        db.DateTime, default=datetime.datetime.utcnow, nullable=False
233
+    )
234
+    reversed = db.Column(db.Boolean, default=False, nullable=False)
235
+
236
+    def __repr__(self) -> str:
237
+        return (
238
+            f"<Consumption: {self.consumption_type.name} for {self.person.full_name}>"
239
+        )
240
+
241
+    @property
242
+    def as_dict(self) -> dict:
243
+        return {
244
+            "consumption_id": self.consumption_id,
245
+            "person_id": self.person_id,
246
+            "consumption_type_id": self.consumption_type_id,
247
+            "settlement_id": self.settlement_id,
248
+            "created_at": self.created_at.isoformat(),
249
+            "reversed": self.reversed,
250
+        }

+ 0 - 0
piket_server/routes/__init__.py


+ 121 - 0
piket_server/routes/aardbei.py

@@ -0,0 +1,121 @@
1
+from typing import Any, Dict, List, Tuple, Union
2
+
3
+from flask import request
4
+
5
+from piket_server.aardbei_sync import (
6
+    AARDBEI_ENDPOINT,
7
+    ActivityId,
8
+    get_activity,
9
+    AardbeiLink,
10
+    AardbeiSyncError,
11
+    create_missing,
12
+    get_aardbei_people,
13
+    match_activity,
14
+    get_activities,
15
+    link_matches,
16
+    match_local_aardbei,
17
+    update_names,
18
+)
19
+from piket_server.flask import app, db
20
+
21
+
22
+def common_prepare_aardbei_sync(
23
+    token: str, endpoint: str
24
+) -> Union[AardbeiSyncError, AardbeiLink]:
25
+    aardbei_people = get_aardbei_people(token, endpoint)
26
+
27
+    if isinstance(aardbei_people, AardbeiSyncError):
28
+        return aardbei_people
29
+
30
+    aardbei_activities = get_activities(token, endpoint)
31
+
32
+    if isinstance(aardbei_activities, AardbeiSyncError):
33
+        return aardbei_activities
34
+
35
+    return match_local_aardbei(aardbei_people)
36
+
37
+
38
+@app.route("/aardbei/diff_people", methods=["POST"])
39
+def aardbei_diff() -> Tuple[Dict[str, Any], int]:
40
+    data: Dict[str, str] = request.json
41
+    link = common_prepare_aardbei_sync(
42
+        data["token"], data.get("endpoint", AARDBEI_ENDPOINT)
43
+    )
44
+
45
+    if isinstance(link, AardbeiSyncError):
46
+        return {"error": link.value}, 503
47
+
48
+    return (
49
+        {
50
+            "num_changes": link.num_changes,
51
+            "new_people": [member.person.full_name for member in link.remote_only],
52
+            "link_existing": [match.local.full_name for match in link.matches],
53
+            "altered_name": [match.local.full_name for match in link.matches],
54
+        },
55
+        200,
56
+    )
57
+
58
+
59
+@app.route("/aardbei/sync_people", methods=["POST"])
60
+def aardbei_apply() -> Union[Tuple[Dict[str, Any], int]]:
61
+    data: Dict[str, str] = request.json
62
+    link = common_prepare_aardbei_sync(
63
+        data["token"], data.get("endpoint", AARDBEI_ENDPOINT)
64
+    )
65
+
66
+    if isinstance(link, AardbeiSyncError):
67
+        return {"error": link.value}, 503
68
+
69
+    link_matches(link.matches)
70
+    create_missing(link.remote_only)
71
+    update_names(link.altered_name)
72
+
73
+    db.session.commit()
74
+
75
+    return (
76
+        {
77
+            "num_changes": link.num_changes,
78
+            "new_people": [member.person.full_name for member in link.remote_only],
79
+            "link_existing": [match.local.full_name for match in link.matches],
80
+            "altered_name": [match.local.full_name for match in link.altered_name],
81
+        },
82
+        200,
83
+    )
84
+
85
+
86
+@app.route("/aardbei/get_activities", methods=["POST"])
87
+def aardbei_get_activities() -> Tuple[Dict[str, object], int]:
88
+    data: Dict[str, str] = request.json
89
+    activities = get_activities(data["token"], data.get("endpoint", AARDBEI_ENDPOINT))
90
+
91
+    if isinstance(activities, AardbeiSyncError):
92
+        return {"error": activities.value}, 503
93
+
94
+    return {"activities": [x.as_json_dict for x in activities]}, 200
95
+
96
+
97
+@app.route("/aardbei/apply_activity", methods=["POST"])
98
+def aardbei_apply_activity() -> Tuple[Dict[str, Any], int]:
99
+    data: Dict[str, Union[str, int]] = request.json
100
+    aid = data["activity_id"]
101
+    token = data["token"]
102
+    endpoint = data["endpoint"]
103
+
104
+    if not isinstance(aid, int):
105
+        return {"error": "nonnumeric_activity_id"}, 400
106
+
107
+    if not isinstance(token, str):
108
+        return {"error": "illtyped_token"}, 400
109
+
110
+    if not isinstance(endpoint, str):
111
+        return {"error": "illtyped_endpoint"}, 400
112
+
113
+    activity = get_activity(activity_id=ActivityId(aid), token=token, endpoint=endpoint)
114
+
115
+    if isinstance(activity, AardbeiSyncError):
116
+        return {"error": activity.value}, 503
117
+
118
+    match_activity(activity)
119
+    db.session.commit()
120
+
121
+    return (activity.as_json_dict, 200)

+ 64 - 0
piket_server/routes/consumption_types.py

@@ -0,0 +1,64 @@
1
+"""
2
+Provides routes related to managing ConsumptionType objects.
3
+"""
4
+
5
+from sqlalchemy.exc import SQLAlchemyError
6
+from flask import jsonify, request
7
+
8
+from piket_server.models import ConsumptionType
9
+from piket_server.flask import app, db
10
+
11
+
12
+@app.route("/consumption_types", methods=["GET"])
13
+def get_consumption_types():
14
+    """ Return a list of currently active consumption types. """
15
+    try:
16
+        active = int(request.args.get("active", 1))
17
+
18
+    except ValueError:
19
+        return {}, 400
20
+
21
+    ctypes = ConsumptionType.query.filter_by(active=active).all()
22
+    result = [ct.as_dict for ct in ctypes]
23
+    return jsonify(consumption_types=result)
24
+
25
+
26
+@app.route("/consumption_types/<int:consumption_type_id>", methods=["GET"])
27
+def get_consumption_type(consumption_type_id: int):
28
+    ct = ConsumptionType.query.get_or_404(consumption_type_id)
29
+
30
+    return jsonify(consumption_type=ct.as_dict)
31
+
32
+
33
+@app.route("/consumption_types", methods=["POST"])
34
+def add_consumption_type():
35
+    """ Add a new ConsumptionType.  """
36
+    json = request.get_json()
37
+
38
+    if not json:
39
+        return jsonify({"error": "Could not parse JSON."}), 400
40
+
41
+    data = json.get("consumption_type") or {}
42
+    ct = ConsumptionType(name=data.get("name"), icon=data.get("icon"))
43
+
44
+    try:
45
+        db.session.add(ct)
46
+        db.session.commit()
47
+    except SQLAlchemyError:
48
+        return jsonify({"error": "Invalid arguments for ConsumptionType."}), 400
49
+
50
+    return jsonify(consumption_type=ct.as_dict), 201
51
+
52
+
53
+@app.route("/consumption_types/<int:consumption_type_id>", methods=["PATCH"])
54
+def activate_consumption_type(consumption_type_id: int):
55
+    ct = ConsumptionType.query.get_or_404(consumption_type_id)
56
+
57
+    data = request.json["consumption_type"]
58
+    new_active = data.get("active", True)
59
+
60
+    ct.active = new_active
61
+    db.session.add(ct)
62
+    db.session.commit()
63
+
64
+    return jsonify(consumption_type=ct.as_dict), 200

+ 36 - 0
piket_server/routes/consumptions.py

@@ -0,0 +1,36 @@
1
+"""
2
+Provides routes related to Consumption objects.
3
+"""
4
+
5
+from flask import jsonify
6
+from sqlalchemy.exc import SQLAlchemyError
7
+
8
+from piket_server.flask import app, db
9
+from piket_server.models import Consumption
10
+
11
+
12
+@app.route("/consumptions/<int:consumption_id>", methods=["DELETE"])
13
+def reverse_consumption(consumption_id: int):
14
+    """ Reverse a consumption. """
15
+    consumption = Consumption.query.get_or_404(consumption_id)
16
+
17
+    if consumption.reversed:
18
+        return (
19
+            jsonify(
20
+                {
21
+                    "error": "Consumption already reversed",
22
+                    "consumption": consumption.as_dict,
23
+                }
24
+            ),
25
+            409,
26
+        )
27
+
28
+    try:
29
+        consumption.reversed = True
30
+        db.session.add(consumption)
31
+        db.session.commit()
32
+
33
+    except SQLAlchemyError:
34
+        return jsonify({"error": "Database error."}), 500
35
+
36
+    return jsonify(consumption=consumption.as_dict), 200

+ 46 - 0
piket_server/routes/exports.py

@@ -0,0 +1,46 @@
1
+"""
2
+Provides routes for managing Export objects.
3
+"""
4
+
5
+from flask import jsonify
6
+from sqlalchemy.exc import SQLAlchemyError
7
+
8
+from piket_server.flask import app, db
9
+from piket_server.models import Export, Settlement
10
+
11
+@app.route("/exports", methods=["GET"])
12
+def get_exports():
13
+    """ Return a list of the created Exports. """
14
+    result = Export.query.all()
15
+    return jsonify(exports=[e.as_dict for e in result])
16
+
17
+
18
+@app.route("/exports/<int:export_id>", methods=["GET"])
19
+def get_export(export_id: int):
20
+    """ Return an overview for the given Export. """
21
+    e = Export.query.get_or_404(export_id)
22
+
23
+    ss = [s.as_dict for s in e.settlements]
24
+
25
+    return jsonify(export=e.as_dict, settlements=ss)
26
+
27
+
28
+@app.route("/exports", methods=["POST"])
29
+def add_export():
30
+    """ Create an Export, and link all un-exported Settlements to it. """
31
+    # Assert that there are Settlements to be exported.
32
+    s_count = Settlement.query.filter_by(export=None).count()
33
+    if s_count == 0:
34
+        return jsonify(error="No un-exported Settlements."), 403
35
+
36
+    e = Export()
37
+
38
+    db.session.add(e)
39
+    db.session.commit()
40
+
41
+    Settlement.query.filter_by(export=None).update({"export_id": e.export_id})
42
+    db.session.commit()
43
+
44
+    ss = [s.as_dict for s in e.settlements]
45
+
46
+    return jsonify(export=e.as_dict, settlements=ss), 201

+ 38 - 0
piket_server/routes/general.py

@@ -0,0 +1,38 @@
1
+"""
2
+Provides general routes.
3
+"""
4
+
5
+from flask import jsonify
6
+
7
+from piket_server.flask import app
8
+from piket_server.models import Consumption
9
+
10
+
11
+@app.route("/ping")
12
+def ping() -> str:
13
+    """ Return a status ping. """
14
+    return "Pong"
15
+
16
+
17
+@app.route("/status")
18
+def status():
19
+    """ Return a status dict with info about the database. """
20
+    unsettled_q = Consumption.query.filter_by(settlement=None).filter_by(reversed=False)
21
+
22
+    unsettled = unsettled_q.count()
23
+
24
+    first = None
25
+    last = None
26
+    if unsettled:
27
+        last = (
28
+            unsettled_q.order_by(Consumption.created_at.desc())
29
+            .first()
30
+            .created_at.isoformat()
31
+        )
32
+        first = (
33
+            unsettled_q.order_by(Consumption.created_at.asc())
34
+            .first()
35
+            .created_at.isoformat()
36
+        )
37
+
38
+    return jsonify({"unsettled": {"amount": unsettled, "first": first, "last": last}})

+ 122 - 0
piket_server/routes/people.py

@@ -0,0 +1,122 @@
1
+"""
2
+Provides routes related to managing Person objects.
3
+"""
4
+
5
+from flask import jsonify, request
6
+from sqlalchemy.exc import SQLAlchemyError
7
+
8
+from piket_server.models import Consumption, Person
9
+from piket_server.flask import app, db
10
+
11
+
12
+@app.route("/people", methods=["GET"])
13
+def get_people():
14
+    """ Return a list of currently known people. """
15
+    people = Person.query.order_by(Person.full_name).all()
16
+    q = Person.query.order_by(Person.full_name)
17
+    if request.args.get("active"):
18
+        active_status = request.args.get("active", type=int)
19
+        q = q.filter_by(active=active_status)
20
+    people = q.all()
21
+    result = [person.as_dict for person in people]
22
+    return jsonify(people=result)
23
+
24
+
25
+@app.route("/people/<int:person_id>", methods=["GET"])
26
+def get_person(person_id: int):
27
+    person = Person.query.get_or_404(person_id)
28
+
29
+    return jsonify(person=person.as_dict)
30
+
31
+
32
+@app.route("/people", methods=["POST"])
33
+def add_person():
34
+    """
35
+    Add a new person.
36
+
37
+    Required parameters:
38
+    - name (str)
39
+    """
40
+    json = request.get_json()
41
+
42
+    if not json:
43
+        return jsonify({"error": "Could not parse JSON."}), 400
44
+
45
+    data = json.get("person") or {}
46
+    person = Person(
47
+        full_name=data.get("full_name"),
48
+        active=data.get("active", False),
49
+        display_name=data.get("display_name", None),
50
+    )
51
+
52
+    try:
53
+        db.session.add(person)
54
+        db.session.commit()
55
+    except SQLAlchemyError:
56
+        return jsonify({"error": "Invalid arguments for Person."}), 400
57
+
58
+    return jsonify(person=person.as_dict), 201
59
+
60
+
61
+@app.route("/people/<int:person_id>", methods=["PATCH"])
62
+def update_person(person_id: int):
63
+    person = Person.query.get_or_404(person_id)
64
+
65
+    data = request.json["person"]
66
+    changed = False
67
+
68
+    if "active" in data:
69
+        person.active = data["active"]
70
+        changed = True
71
+
72
+    if "full_name" in data:
73
+        person.full_name = data["full_name"]
74
+        changed = True
75
+
76
+    if "display_name" in data:
77
+        person.display_name = data["display_name"]
78
+        changed = True
79
+
80
+    if changed:
81
+        db.session.add(person)
82
+        db.session.commit()
83
+
84
+    return jsonify(person=person.as_dict)
85
+
86
+
87
+@app.route("/people/<int:person_id>/add_consumption", methods=["POST"])
88
+def add_consumption(person_id: int):
89
+    person = Person.query.get_or_404(person_id)
90
+
91
+    consumption = Consumption(person=person, consumption_type_id=1)
92
+    try:
93
+        db.session.add(consumption)
94
+        db.session.commit()
95
+    except SQLAlchemyError:
96
+        return (
97
+            jsonify(
98
+                {"error": "Invalid Consumption parameters.", "person": person.as_dict}
99
+            ),
100
+            400,
101
+        )
102
+
103
+    return jsonify(person=person.as_dict, consumption=consumption.as_dict), 201
104
+
105
+
106
+@app.route("/people/<int:person_id>/add_consumption/<int:ct_id>", methods=["POST"])
107
+def add_consumption2(person_id: int, ct_id: int):
108
+    person = Person.query.get_or_404(person_id)
109
+
110
+    consumption = Consumption(person=person, consumption_type_id=ct_id)
111
+    try:
112
+        db.session.add(consumption)
113
+        db.session.commit()
114
+    except SQLAlchemyError:
115
+        return (
116
+            jsonify(
117
+                {"error": "Invalid Consumption parameters.", "person": person.as_dict}
118
+            ),
119
+            400,
120
+        )
121
+
122
+    return jsonify(person=person.as_dict, consumption=consumption.as_dict), 201

+ 49 - 0
piket_server/routes/settlements.py

@@ -0,0 +1,49 @@
1
+"""
2
+Provides routes for managing Settlement objects.
3
+"""
4
+
5
+from sqlalchemy.exc import SQLAlchemyError
6
+from flask import jsonify, request
7
+
8
+from piket_server.flask import app, db
9
+from piket_server.models import Consumption, Settlement
10
+
11
+
12
+@app.route("/settlements", methods=["GET"])
13
+def get_settlements():
14
+    """ Return a list of the active Settlements. """
15
+    result = Settlement.query.all()
16
+    return jsonify(settlements=[s.as_dict for s in result])
17
+
18
+
19
+@app.route("/settlements/<int:settlement_id>", methods=["GET"])
20
+def get_settlement(settlement_id: int):
21
+    """ Show full details for a single Settlement. """
22
+    s = Settlement.query.get_or_404(settlement_id)
23
+
24
+    per_person = s.per_person
25
+
26
+    return jsonify(settlement=s.as_dict, count_info=per_person)
27
+
28
+
29
+@app.route("/settlements", methods=["POST"])
30
+def add_settlement():
31
+    """ Create a Settlement, and link all un-settled Consumptions to it. """
32
+    json = request.get_json()
33
+
34
+    if not json:
35
+        return jsonify({"error": "Could not parse JSON."}), 400
36
+
37
+    data = json.get("settlement") or {}
38
+    s = Settlement(name=data["name"])
39
+
40
+    db.session.add(s)
41
+    db.session.commit()
42
+
43
+    Consumption.query.filter_by(settlement=None).update(
44
+        {"settlement_id": s.settlement_id}
45
+    )
46
+
47
+    db.session.commit()
48
+
49
+    return jsonify(settlement=s.as_dict, count_info=s.per_person)

+ 4 - 3
piket_server/seed.py

@@ -6,7 +6,8 @@ import argparse
6 6
 import csv
7 7
 import os
8 8
 
9
-from piket_server import db, Person, Settlement, ConsumptionType, Consumption
9
+from piket_server.models import Person, Settlement, ConsumptionType, Consumption
10
+from piket_server.flask import db
10 11
 
11 12
 
12 13
 def main():
@@ -52,8 +53,8 @@ def cmd_clear(args) -> None:
52 53
         print("All data removed. Recreating database...")
53 54
         db.create_all()
54 55
 
55
-        from alembic.config import Config
56
-        from alembic import command
56
+        from alembic.config import Config  # type: ignore
57
+        from alembic import command  # type: ignore
57 58
 
58 59
         alembic_cfg = Config(os.path.join(os.path.dirname(__file__), "alembic.ini"))
59 60
         command.stamp(alembic_cfg, "head")

+ 10 - 0
piket_server/util.py

@@ -0,0 +1,10 @@
1
+import datetime
2
+from typing import Optional
3
+
4
+
5
+def fmt_datetime(x: Optional[datetime.datetime]) -> Optional[str]:
6
+    """Format a datetime as ISO 8601, if it's not None."""
7
+    if x is not None:
8
+        return x.isoformat()
9
+
10
+    return None

+ 3 - 2
setup.py

@@ -18,14 +18,15 @@ setup(
18 18
     entry_points={
19 19
         "console_scripts": [
20 20
             "piket-client=piket_client.gui:main",
21
+            "piket-cli=piket_client.cli:cli",
21 22
             "piket-seed=piket_server.seed:main",
22 23
         ]
23 24
     },
24 25
     install_requires=[],
25 26
     extras_require={
26
-        "dev": ["black", "pylint"],
27
+        "dev": ["black", "pylint", "mypy", "isort"],
27 28
         "server": ["Flask", "SQLAlchemy", "Flask-SQLAlchemy", "alembic", "uwsgi"],
28
-        "client": ["PySide2", "qdarkstyle>=2.6.0", "requests", "simpleaudio"],
29
+        "client": ["PySide2", "qdarkstyle>=2.6.0", "requests", "simpleaudio", "click", "prettytable"],
29 30
         "osk": ["dbus-python"],
30 31
         "sentry": ["raven"],
31 32
     },