-
Notifications
You must be signed in to change notification settings - Fork 25
Expand file tree
/
Copy pathtask.py
More file actions
633 lines (501 loc) · 21.2 KB
/
task.py
File metadata and controls
633 lines (501 loc) · 21.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# See https://peps.python.org/pep-0563/
from __future__ import annotations
import math
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union
from durabletask.entities import DurableEntity, EntityInstanceId, EntityLock, EntityContext
import durabletask.internal.helpers as pbh
import durabletask.internal.orchestrator_service_pb2 as pb
# Generic type variables used throughout this module.
# T: generic task result type.
T = TypeVar('T')
# TInput: input type accepted by an orchestrator/activity/entity function.
TInput = TypeVar('TInput')
# TOutput: output type produced by an orchestrator/activity/entity function.
TOutput = TypeVar('TOutput')
class OrchestrationContext(ABC):
    """Abstract context passed to orchestrator functions.

    Exposes replay-safe primitives for scheduling activities,
    sub-orchestrations, durable timers, entity operations, and external
    events. Concrete implementations are provided by the runtime.
    """

    @property
    @abstractmethod
    def instance_id(self) -> str:
        """Get the ID of the current orchestration instance.

        The instance ID is generated and fixed when the orchestrator function
        is scheduled. It can be either auto-generated, in which case it is
        formatted as a UUID, or it can be user-specified with any format.

        Returns
        -------
        str
            The ID of the current orchestration instance.
        """
        pass

    @property
    @abstractmethod
    def version(self) -> Optional[str]:
        """Get the version of the orchestration instance.

        This version is set when the orchestration is scheduled and can be used
        to determine which version of the orchestrator function is being executed.

        Returns
        -------
        Optional[str]
            The version of the orchestration instance, or None if not set.
        """
        pass

    @property
    @abstractmethod
    def current_utc_datetime(self) -> datetime:
        """Get the current date/time as UTC.

        This date/time value is derived from the orchestration history. It
        always returns the same value at specific points in the orchestrator
        function code, making it deterministic and safe for replay.

        Returns
        -------
        datetime
            The current timestamp in a way that is safe for use by orchestrator functions.
        """
        pass

    @property
    @abstractmethod
    def is_replaying(self) -> bool:
        """Get the value indicating whether the orchestrator is replaying from history.

        This property is useful when there is logic that needs to run only when
        the orchestrator function is _not_ replaying. For example, certain
        types of application logging may become too noisy when duplicated as
        part of orchestrator function replay. The orchestrator code could check
        to see whether the function is being replayed and then issue the log
        statements when this value is `false`.

        Returns
        -------
        bool
            Value indicating whether the orchestrator function is currently replaying.
        """
        pass

    @abstractmethod
    def set_custom_status(self, custom_status: Any) -> None:
        """Set the orchestration instance's custom status.

        Parameters
        ----------
        custom_status: Any
            A JSON-serializable custom status value to set.
        """
        pass

    @abstractmethod
    def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
        """Create a Timer Task to fire at the specified deadline.

        Parameters
        ----------
        fire_at: datetime.datetime | datetime.timedelta
            The time for the timer to trigger or a time delta from now.

        Returns
        -------
        Task
            A Durable Timer Task that schedules the timer to wake up the orchestrator.
        """
        pass

    @abstractmethod
    def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
                      input: Optional[TInput] = None,
                      retry_policy: Optional[RetryPolicy] = None,
                      tags: Optional[dict[str, str]] = None) -> CompletableTask[TOutput]:
        """Schedule an activity for execution.

        Parameters
        ----------
        activity: Union[Activity[TInput, TOutput], str]
            A reference to the activity function to call, or its registered name.
        input: Optional[TInput]
            The JSON-serializable input (or None) to pass to the activity.
        retry_policy: Optional[RetryPolicy]
            The retry policy to use for this activity call.
        tags: Optional[dict[str, str]]
            Optional tags to associate with the activity invocation.

        Returns
        -------
        Task
            A Durable Task that completes when the called activity function completes or fails.
        """
        pass

    @abstractmethod
    def call_entity(self,
                    entity: EntityInstanceId,
                    operation: str,
                    input: Optional[TInput] = None) -> CompletableTask[Any]:
        """Schedule entity function for execution.

        Parameters
        ----------
        entity: EntityInstanceId
            The ID of the entity instance to call.
        operation: str
            The name of the operation to invoke on the entity.
        input: Optional[TInput]
            The optional JSON-serializable input to pass to the entity function.

        Returns
        -------
        Task
            A Durable Task that completes when the called entity function completes or fails.
        """
        pass

    @abstractmethod
    def signal_entity(
            self,
            entity_id: EntityInstanceId,
            operation_name: str,
            input: Optional[TInput] = None
    ) -> None:
        """Signal an entity function for execution (fire-and-forget; no task is returned).

        Parameters
        ----------
        entity_id: EntityInstanceId
            The ID of the entity instance to signal.
        operation_name: str
            The name of the operation to invoke on the entity.
        input: Optional[TInput]
            The optional JSON-serializable input to pass to the entity function.
        """
        pass

    @abstractmethod
    def lock_entities(self, entities: list[EntityInstanceId]) -> CompletableTask[EntityLock]:
        """Creates a Task object that locks the specified entity instances.

        The locks will be acquired the next time the orchestrator yields.
        Best practice is to immediately yield this Task and enter the returned EntityLock.
        The lock is released when the EntityLock is exited.

        Parameters
        ----------
        entities: list[EntityInstanceId]
            The list of entity instance IDs to lock.

        Returns
        -------
        EntityLock
            A context manager object that releases the locks when exited.
        """
        pass

    @abstractmethod
    def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput], str], *,
                              input: Optional[TInput] = None,
                              instance_id: Optional[str] = None,
                              retry_policy: Optional[RetryPolicy] = None,
                              version: Optional[str] = None) -> CompletableTask[TOutput]:
        """Schedule sub-orchestrator function for execution.

        Parameters
        ----------
        orchestrator: Orchestrator[TInput, TOutput]
            A reference to the orchestrator function to call, or its registered name.
        input: Optional[TInput]
            The optional JSON-serializable input to pass to the orchestrator function.
        instance_id: Optional[str]
            A unique ID to use for the sub-orchestration instance. If not specified, a
            random UUID will be used.
        retry_policy: Optional[RetryPolicy]
            The retry policy to use for this sub-orchestrator call.
        version: Optional[str]
            The version of the sub-orchestrator to invoke, if applicable.

        Returns
        -------
        Task
            A Durable Task that completes when the called sub-orchestrator completes or fails.
        """
        pass

    # TODO: Add a timeout parameter, which allows the task to be canceled if the event is
    # not received within the specified timeout. This requires support for task cancellation.
    @abstractmethod
    def wait_for_external_event(self, name: str) -> CompletableTask:
        """Wait asynchronously for an event to be raised with the name `name`.

        Parameters
        ----------
        name : str
            The event name of the event that the task is waiting for.

        Returns
        -------
        Task
            A Durable Task that completes when the event is received.
        """
        pass

    @abstractmethod
    def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None:
        """Continue the orchestration execution as a new instance.

        Parameters
        ----------
        new_input : Any
            The new input to use for the new orchestration instance.
        save_events : bool
            A flag indicating whether to add any unprocessed external events to the
            new orchestration history.
        """
        pass

    @abstractmethod
    def new_uuid(self) -> str:
        """Create a new UUID that is safe for replay within an orchestration or operation.

        The default implementation of this method creates a name-based UUID
        using the algorithm from RFC 4122 §4.3. The name input used to generate
        this value is a combination of the orchestration instance ID, the current UTC datetime,
        and an internally managed counter.

        Returns
        -------
        str
            New UUID that is safe for replay within an orchestration or operation.
        """
        pass

    @abstractmethod
    def _exit_critical_section(self) -> None:
        # Internal hook: releases an entity critical section (see lock_entities /
        # EntityLock). Not part of the public orchestrator-facing API.
        pass
class FailureDetails:
    """Immutable, structured description of a task failure.

    Carries the error message, the error type name, and (optionally) the
    stack trace captured where the failure occurred.
    """

    def __init__(self, message: str, error_type: str, stack_trace: Optional[str]):
        self._message = message
        self._error_type = error_type
        self._stack_trace = stack_trace

    @property
    def message(self) -> str:
        """The human-readable error message."""
        return self._message

    @property
    def error_type(self) -> str:
        """The name of the error/exception type."""
        return self._error_type

    @property
    def stack_trace(self) -> Optional[str]:
        """The stack trace of the failure, or None if unavailable."""
        return self._stack_trace
class TaskFailedError(Exception):
    """Exception type for all orchestration task failures."""

    def __init__(self, message: str, details: Union[pb.TaskFailureDetails, Exception]):
        super().__init__(message)
        # Normalize a raw Exception into protobuf failure details first so the
        # rest of the constructor deals with a single representation.
        failure = pbh.new_failure_details(details) if isinstance(details, Exception) else details
        stack = failure.stackTrace.value if not pbh.is_empty(failure.stackTrace) else None
        self._details = FailureDetails(failure.errorMessage, failure.errorType, stack)

    @property
    def details(self) -> FailureDetails:
        """Structured details describing the underlying failure."""
        return self._details
class NonDeterminismError(Exception):
    # NOTE(review): raised by the runtime, not within this chunk. Judging by the
    # name it signals orchestrator code diverging from recorded history during
    # replay — confirm against the executor before documenting further.
    pass
class OrchestrationStateError(Exception):
    # NOTE(review): raised by the runtime, not within this chunk. Presumably
    # signals an invalid orchestration state transition — confirm against callers.
    pass
class Task(ABC, Generic[T]):
    """Abstract base class for asynchronous tasks in a durable orchestration."""
    _result: T
    _exception: Optional[TaskFailedError]
    _parent: Optional[CompositeTask[T]]

    def __init__(self) -> None:
        super().__init__()
        self._is_complete = False
        self._exception = None
        self._parent = None

    @property
    def is_complete(self) -> bool:
        """True once the task has finished, whether it succeeded or failed."""
        return self._is_complete

    @property
    def is_failed(self) -> bool:
        """True if the task finished with a failure."""
        return self._exception is not None

    def get_result(self) -> T:
        """Return the task's result.

        Raises ValueError if the task has not yet completed, or the recorded
        TaskFailedError if the task completed with a failure.
        """
        if not self._is_complete:
            raise ValueError('The task has not completed.')
        if self._exception is not None:
            raise self._exception
        return self._result

    def get_exception(self) -> TaskFailedError:
        """Return the failure that completed this task.

        Raises ValueError if the task did not fail.
        """
        if self._exception is None:
            raise ValueError('The task has not failed.')
        return self._exception
class CompositeTask(Task[T]):
    """A task that is composed of other tasks."""
    _tasks: list[Task]

    def __init__(self, tasks: list[Task]):
        """Wire up the children: register self as each child's parent and
        account for any child that is already complete at construction time."""
        super().__init__()
        self._tasks = tasks
        self._completed_tasks = 0
        self._failed_tasks = 0
        for child in tasks:
            child._parent = self
            if child.is_complete:
                self.on_child_completed(child)

    def get_tasks(self) -> list[Task]:
        """Return the child tasks in their original order."""
        return self._tasks

    @abstractmethod
    def on_child_completed(self, task: Task[T]):
        """Callback invoked by a child task when it completes or fails."""
        pass
class WhenAllTask(CompositeTask[list[T]]):
    """A task that completes when all of its child tasks complete.

    If any child fails, this task completes with the first failure observed.
    When every child succeeds, the result is the list of child results in the
    same order as the tasks passed to the constructor.
    """

    def __init__(self, tasks: list[Task[T]]):
        # CompositeTask.__init__ zeroes the completion counters and then invokes
        # on_child_completed for any child that is already complete. The previous
        # revision re-zeroed the counters *after* super().__init__, silently
        # discarding those completions and corrupting pending_tasks /
        # get_completed_tasks — so no resets happen here.
        super().__init__(tasks)

    @property
    def pending_tasks(self) -> int:
        """Returns the number of tasks that have not yet completed."""
        return len(self._tasks) - self._completed_tasks

    def on_child_completed(self, task: Task[T]):
        """Record one child's completion; complete this task when all children
        have completed, or fail it with the first child failure.

        NOTE(review): once a child failure marks this task complete, subsequent
        child completions will hit the ValueError below — presumably the runtime
        stops dispatching to a completed parent; confirm against the executor.
        """
        if self.is_complete:
            raise ValueError('The task has already completed.')
        self._completed_tasks += 1
        if task.is_failed and self._exception is None:
            self._exception = task.get_exception()
            self._is_complete = True
        if self._completed_tasks == len(self._tasks):
            # The order of the result MUST match the order of the tasks provided
            # to the constructor.
            self._result = [task.get_result() for task in self._tasks]
            self._is_complete = True

    def get_completed_tasks(self) -> int:
        """Returns how many child tasks have completed (or failed) so far."""
        return self._completed_tasks
class CompletableTask(Task[T]):
    """A task whose completion or failure is driven externally by the runtime."""

    def __init__(self):
        super().__init__()
        self._retryable_parent = None

    def complete(self, result: T):
        """Mark the task as successfully completed with `result` and notify the
        parent composite task, if any. Raises ValueError if already complete."""
        if self._is_complete:
            raise ValueError('The task has already completed.')
        self._result = result
        self._is_complete = True
        parent = self._parent
        if parent is not None:
            parent.on_child_completed(self)

    def fail(self, message: str, details: Union[Exception, pb.TaskFailureDetails]):
        """Mark the task as failed and notify the parent composite task, if
        any. Raises ValueError if already complete."""
        if self._is_complete:
            raise ValueError('The task has already completed.')
        self._exception = TaskFailedError(message, details)
        self._is_complete = True
        parent = self._parent
        if parent is not None:
            parent.on_child_completed(self)
class RetryableTask(CompletableTask[T]):
    """A task that can be retried according to a retry policy."""
    def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction,
                 start_time: datetime, is_sub_orch: bool) -> None:
        """Create a retryable task.

        Parameters
        ----------
        retry_policy: RetryPolicy
            The policy governing retry count, backoff, and timeout.
        action: pb.OrchestratorAction
            The protobuf action to re-issue on each retry attempt.
        start_time: datetime
            When the first attempt started; the retry timeout is measured from here.
        is_sub_orch: bool
            Whether this task wraps a sub-orchestration (vs. an activity).
        """
        super().__init__()
        self._action = action
        self._retry_policy = retry_policy
        # Attempt counter starts at 1: the initial (non-retry) attempt counts.
        self._attempt_count = 1
        self._start_time = start_time
        self._is_sub_orch = is_sub_orch
    def increment_attempt_count(self) -> None:
        """Record that another attempt has been made."""
        self._attempt_count += 1
    def compute_next_delay(self) -> Optional[timedelta]:
        """Return the delay before the next retry attempt, or None when no
        further retry should be scheduled (attempts or timeout exhausted)."""
        if self._attempt_count >= self._retry_policy.max_number_of_attempts:
            return None
        # Deadline after which retries stop; datetime.max means "no timeout".
        retry_expiration: datetime = datetime.max
        # NOTE(review): retry_timeout is a timedelta, so the comparison against
        # datetime.max can never be equal and looks vestigial — confirm intent.
        if self._retry_policy.retry_timeout is not None and self._retry_policy.retry_timeout != datetime.max:
            retry_expiration = self._start_time + self._retry_policy.retry_timeout
        if self._retry_policy.backoff_coefficient is None:
            backoff_coefficient = 1.0
        else:
            backoff_coefficient = self._retry_policy.backoff_coefficient
        # NOTE(review): uses wall-clock utcnow() (naive, deprecated in 3.12);
        # assumes _start_time is also naive UTC — confirm against callers.
        if datetime.utcnow() < retry_expiration:
            # Exponential backoff: first_retry_interval * coeff^(attempts - 1),
            # capped at max_retry_interval when one is configured.
            next_delay_f = math.pow(backoff_coefficient, self._attempt_count - 1) * self._retry_policy.first_retry_interval.total_seconds()
            if self._retry_policy.max_retry_interval is not None:
                next_delay_f = min(next_delay_f, self._retry_policy.max_retry_interval.total_seconds())
            return timedelta(seconds=next_delay_f)
        return None
class TimerTask(CompletableTask[T]):
    # A durable-timer task. When the timer implements a retry backoff delay,
    # the owning RetryableTask is recorded via set_retryable_parent —
    # presumably so the executor can resume the retry loop when the timer
    # fires; confirm against the executor.
    def __init__(self) -> None:
        super().__init__()
    def set_retryable_parent(self, retryable_task: RetryableTask) -> None:
        # Associate this timer with the RetryableTask whose backoff it implements.
        self._retryable_parent = retryable_task
class WhenAnyTask(CompositeTask[Task]):
    """A task that completes when any of its child tasks complete."""

    def __init__(self, tasks: list[Task]):
        super().__init__(tasks)

    def on_child_completed(self, task: Task):
        """Record the first child to finish as this task's result; completions
        arriving after the winner are silently ignored."""
        if self.is_complete:
            # A winner has already been recorded — drop subsequent signals.
            return
        self._is_complete = True
        self._result = task
def when_all(tasks: list[Task[T]]) -> WhenAllTask[T]:
    """Combine `tasks` into a single task that completes once every child
    completes, or fails as soon as any child fails."""
    combined = WhenAllTask(tasks)
    return combined
def when_any(tasks: list[Task]) -> WhenAnyTask:
    """Combine `tasks` into a single task that completes as soon as any child
    completes or fails; the winning child task is the result."""
    race = WhenAnyTask(tasks)
    return race
class ActivityContext:
    """Execution context supplied to activity functions."""

    def __init__(self, orchestration_id: str, task_id: int):
        self._orchestration_id = orchestration_id
        self._task_id = task_id

    @property
    def orchestration_id(self) -> str:
        """The ID of the orchestration instance that scheduled this activity."""
        return self._orchestration_id

    @property
    def task_id(self) -> int:
        """The task ID associated with this activity invocation.

        The task ID is an auto-incrementing integer that is unique within the
        scope of the orchestration instance. It can be used to distinguish
        between multiple activity invocations that are part of the same
        orchestration instance.
        """
        return self._task_id
# Orchestrators are generators that yield tasks, receive any type, and return
# TOutput (or return TOutput directly without yielding).
Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task[Any], Any, TOutput], TOutput]]
# Activities are simple functions that can be scheduled by orchestrators.
Activity = Callable[[ActivityContext, TInput], TOutput]
# Entities are either plain functions taking an EntityContext, or classes
# derived from DurableEntity.
Entity = Union[Callable[[EntityContext, TInput], TOutput], type[DurableEntity]]
class RetryPolicy:
    """Represents the retry policy for an orchestration or activity function."""

    def __init__(self, *,
                 first_retry_interval: timedelta,
                 max_number_of_attempts: int,
                 backoff_coefficient: Optional[float] = 1.0,
                 max_retry_interval: Optional[timedelta] = None,
                 retry_timeout: Optional[timedelta] = None):
        """Creates a new RetryPolicy instance.

        Parameters
        ----------
        first_retry_interval : timedelta
            The retry interval to use for the first retry attempt.
        max_number_of_attempts : int
            The maximum number of retry attempts.
        backoff_coefficient : Optional[float]
            The backoff coefficient to use for calculating the next retry interval.
        max_retry_interval : Optional[timedelta]
            The maximum retry interval to use for any retry attempt.
        retry_timeout : Optional[timedelta]
            The maximum amount of time to spend retrying the operation.

        Raises
        ------
        ValueError
            If any argument falls outside its valid range (see checks below).
        """
        # Validate arguments up front so a bad policy fails fast at creation.
        zero = timedelta(seconds=0)
        if first_retry_interval < zero:
            raise ValueError('first_retry_interval must be >= 0')
        if max_number_of_attempts < 1:
            raise ValueError('max_number_of_attempts must be >= 1')
        if backoff_coefficient is not None and backoff_coefficient < 1:
            raise ValueError('backoff_coefficient must be >= 1')
        if max_retry_interval is not None and max_retry_interval < zero:
            raise ValueError('max_retry_interval must be >= 0')
        if retry_timeout is not None and retry_timeout < zero:
            raise ValueError('retry_timeout must be >= 0')
        self._first_retry_interval = first_retry_interval
        self._max_number_of_attempts = max_number_of_attempts
        self._backoff_coefficient = backoff_coefficient
        self._max_retry_interval = max_retry_interval
        self._retry_timeout = retry_timeout

    @property
    def first_retry_interval(self) -> timedelta:
        """The retry interval to use for the first retry attempt."""
        return self._first_retry_interval

    @property
    def max_number_of_attempts(self) -> int:
        """The maximum number of retry attempts."""
        return self._max_number_of_attempts

    @property
    def backoff_coefficient(self) -> Optional[float]:
        """The backoff coefficient to use for calculating the next retry interval."""
        return self._backoff_coefficient

    @property
    def max_retry_interval(self) -> Optional[timedelta]:
        """The maximum retry interval to use for any retry attempt."""
        return self._max_retry_interval

    @property
    def retry_timeout(self) -> Optional[timedelta]:
        """The maximum amount of time to spend retrying the operation."""
        return self._retry_timeout
def get_entity_name(fn: Entity) -> str:
    """Resolve the durable-entity name for `fn`.

    Precedence: an explicit `__durable_entity_name__` attribute wins (even if
    set to None, matching the original hasattr-based check); a DurableEntity
    subclass falls back to its class name; anything else is treated as a plain
    function and named via get_name().
    """
    sentinel = object()
    explicit_name = getattr(fn, "__durable_entity_name__", sentinel)
    if explicit_name is not sentinel:
        return explicit_name
    if isinstance(fn, type) and issubclass(fn, DurableEntity):
        return fn.__name__
    return get_name(fn)
def get_name(fn: Callable) -> str:
    """Returns the name of the provided function.

    Raises ValueError for lambdas, whose `__name__` ('<lambda>') carries no
    useful identity.
    """
    fn_name = fn.__name__
    if fn_name != '<lambda>':
        return fn_name
    raise ValueError('Cannot infer a name from a lambda function. Please provide a name explicitly.')