"""
.. include:: ../../globals.inc
Overview
--------
The |mod_helpers| module contains helper functions that are used by functions throughout the
|sp| package.
API Reference
-------------
"""
import math
from scipy import stats
from scipy.special import comb
from scipy.stats import rv_discrete, rv_continuous
from math import factorial
import numpy as np
from itertools import product
from collections import defaultdict
### CONSTANTS ###
# Very large sentinel values.
# NOTE(review): presumably used elsewhere in the package as stand-ins for
# "infinity" -- confirm against callers.
# NOTE: despite its name, BIG_INT is a float, because ``1e100`` is a float literal.
BIG_INT = 1e100
BIG_FLOAT = 1.0e100
### UTILITY FUNCTIONS ###
def min_of_dict(d):
    """Find the smallest value in a dict together with a key that attains it.

    Parameters
    ----------
    d : dict
        Dict whose values can be compared with one another (typically numbers).

    Returns
    -------
    min_value : float
        Smallest value in ``d``.
    min_key : any
        Key mapped to ``min_value``. If several keys tie, the one that comes
        first in the dict's iteration order is returned.

    Raises
    ------
    TypeError
        If the values of ``d`` cannot be compared with one another.
    ValueError
        If ``d`` is empty.
    """
    # Scan (key, value) pairs once, ranking them by value only.
    min_key, min_value = min(d.items(), key=lambda pair: pair[1])
    return min_value, min_key
def dict_match(d1, d2, require_presence=False, rel_tol=1e-9, abs_tol=0.0):
    """Check whether two dicts have (approximately) equal values for every key.

    Values are compared with :func:`math.isclose`. A key that is missing from
    one dict is treated as having value 0 in that dict, unless
    ``require_presence`` is ``True``, in which case any key that is present in
    only one of the dicts makes the comparison fail.

    Parameters
    ----------
    d1 : dict
        First dict for comparison. Values must be numeric.
    d2 : dict
        Second dict for comparison. Values must be numeric.
    require_presence : bool, optional
        Set to ``True`` to require dicts to have the same keys, or ``False``
        (default) to treat missing keys as 0s.
    rel_tol : float
        Relative tolerance passed to :func:`math.isclose`.
    abs_tol : float
        Absolute tolerance passed to :func:`math.isclose`.

    Returns
    -------
    bool
        ``True`` if the dicts match, ``False`` otherwise.
    """
    def _is_zero(value):
        return math.isclose(value, 0, rel_tol=rel_tol, abs_tol=abs_tol)

    # Check every key of d1 against d2 (or against 0 if d2 lacks the key).
    for key, value in d1.items():
        if key in d2:
            if not math.isclose(value, d2[key], rel_tol=rel_tol, abs_tol=abs_tol):
                return False
        elif require_presence or not _is_zero(value):
            return False

    # Keys shared by both dicts were handled above, so only keys that appear
    # in d2 but NOT in d1 still need to be checked.
    for key, value in d2.items():
        if key not in d1 and (require_presence or not _is_zero(value)):
            return False

    return True
def is_iterable(x):
    """Tell whether ``x`` should be treated as an iterable or as a singleton.

    Strings are deliberately reported as singletons even though Python can
    iterate over them.

    Parameters
    ----------
    x : any
        Object to classify.

    Returns
    -------
    bool
        ``True`` if ``iter(x)`` succeeds and ``x`` is not a string,
        ``False`` otherwise.
    """
    # Strings behave like iterables but are treated as singletons here.
    if isinstance(x, str):
        return False
    try:
        iter(x)
    except TypeError:
        return False
    return True
def is_list(x):
    """Report whether ``x`` is a ``list`` (or an instance of a ``list`` subclass).

    Parameters
    ----------
    x : any
        Object to test.

    Returns
    -------
    bool
        ``True`` when ``x`` is a list, ``False`` for anything else
        (including tuples and other sequences).
    """
    return isinstance(x, list)
def is_set(x):
    """Report whether ``x`` is a ``set`` (or an instance of a ``set`` subclass).

    Parameters
    ----------
    x : any
        Object to test.

    Returns
    -------
    bool
        ``True`` when ``x`` is a set, ``False`` for anything else
        (a ``frozenset`` is not a ``set``).
    """
    return isinstance(x, set)
def is_dict(x):
    """Report whether ``x`` is a ``dict`` (or an instance of a ``dict`` subclass).

    Parameters
    ----------
    x : any
        Object to test.

    Returns
    -------
    bool
        ``True`` when ``x`` is a dict, ``False`` for anything else.
    """
    return isinstance(x, dict)
def is_integer(x):
    """Tell whether ``x`` holds an integer value.

    Parameters
    ----------
    x : float
        Number to check for integrality.

    Returns
    -------
    bool
        ``True`` if ``x`` is an ``int`` (which includes ``bool``, a subclass of
        ``int``) or a ``float`` with no fractional part; ``False`` for any
        other value or type.
    """
    # Genuine ints are integers by definition.
    if isinstance(x, int):
        return True
    # Otherwise only a float with an integral value qualifies.
    return isinstance(x, float) and x.is_integer()
def is_numeric_string(s):
    """Tell whether ``s`` is a string that can be parsed as a number.

    "Parsable as a number" means that ``float(s)`` succeeds.

    Parameters
    ----------
    s : str
        String to check.

    Returns
    -------
    bool
        ``True`` if ``s`` is a ``str`` and ``float(s)`` does not raise
        ``ValueError``; ``False`` otherwise (including when ``s`` is not a string).
    """
    if not isinstance(s, str):
        return False
    # https://stackoverflow.com/q/354038/3453768
    try:
        float(s)
    except ValueError:
        return False
    return True
def is_discrete_distribution(distribution):
    """Report whether a SciPy distribution object is discrete.

    Handles both frozen distributions (which expose the underlying distribution
    through a ``.dist`` attribute) and distribution objects themselves, i.e.,
    instances of ``rv_discrete`` / ``rv_continuous`` subclasses.

    See https://stackoverflow.com/a/61530461/3453768.

    Parameters
    ----------
    distribution : rv_frozen, rv_continuous, or rv_discrete
        The distribution object to check.

    Returns
    -------
    bool
        ``True`` if the distribution is discrete, ``False`` otherwise.


    .. note:: Not reliable if ``distribution`` is not an ``rv_frozen``, ``rv_discrete``,
        or ``rv_continuous`` object.
    """
    # A frozen distribution carries its generator in ``.dist``; anything else
    # is inspected directly.
    underlying = getattr(distribution, 'dist', distribution)
    return isinstance(underlying, rv_discrete)
def is_continuous_distribution(distribution):
    """Report whether a SciPy distribution object is continuous.

    Handles both frozen distributions (which expose the underlying distribution
    through a ``.dist`` attribute) and distribution objects themselves, i.e.,
    instances of ``rv_discrete`` / ``rv_continuous`` subclasses.

    See https://stackoverflow.com/a/61530461/3453768.

    Parameters
    ----------
    distribution : rv_frozen, rv_continuous, or rv_discrete
        The distribution object to check.

    Returns
    -------
    bool
        ``True`` if the distribution is continuous, ``False`` otherwise.


    .. note:: Not reliable if ``distribution`` is not an ``rv_frozen``, ``rv_discrete``,
        or ``rv_continuous`` object.
    """
    # A frozen distribution carries its generator in ``.dist``; anything else
    # is inspected directly.
    underlying = getattr(distribution, 'dist', distribution)
    return isinstance(underlying, rv_continuous)
def nearest_dict_value(key_to_search, the_dict):
    """Return the value in ``the_dict`` corresponding to the key that is closest
    to ``key_to_search``.

    If ``key_to_search`` is itself a key of ``the_dict``, its value is returned
    directly, without scanning the other keys.

    Parameters
    ----------
    key_to_search : float
        The number to search for among the keys.
    the_dict : dict
        The dictionary to search.

    Returns
    -------
    any
        Value stored under ``key_to_search`` if present, otherwise the value
        stored under the key ``k`` minimizing ``abs(k - key_to_search)``.

    Raises
    ------
    ValueError
        If ``the_dict`` is empty.
    """
    # Exact hit: no need to pay for the O(n) nearest-key scan.
    if key_to_search in the_dict:
        return the_dict[key_to_search]

    # https://stackoverflow.com/a/7934624/3453768
    nearest_key = min(the_dict, key=lambda k: abs(k - key_to_search))
    return the_dict[nearest_key]
def find_nearest(array, values, sorted=False, index=None):
    """Determine entries in ``array`` that are closest to each of the
    entries in ``values`` and return their indices.

    Neither array needs to be sorted, but if ``array`` is sorted and ``sorted``
    is set to ``True``, a binary search is used instead of a full scan. A
    dictionary ``index`` mapping values to already-known indices may be
    provided; values found in it are not searched for at all. ``array`` and
    ``values`` need not be the same length.

    Parameters
    ----------
    array : ndarray
        The array to search for values in.
    values : ndarray
        The array (or scalar) whose values should be searched for in ``array``.
    sorted : bool, optional
        If ``True``, treats ``array`` as sorted in ascending order.
    index : dict, optional
        A dictionary from values to indices. Entries that are missing (or
        mapped to NaN) are searched for in ``array`` as usual.

    Returns
    -------
    ind : ndarray
        Array of (integer) indices into ``array``, one per entry of ``values``.
    """
    if index is None:
        index = {}

    array = np.asarray(array)
    # Convert without forcing copy=False: under NumPy 2 ``copy=False`` raises
    # when a copy is unavoidable (e.g., when ``values`` is a plain list).
    values = np.atleast_1d(np.asarray(values))

    # Float array so that NaN can mark "not found in index".
    ind = np.zeros(values.shape)
    for v in range(values.size):
        ind[v] = index.get(values[v], np.nan)
        if not np.isnan(ind[v]):
            continue

        if sorted:
            # https://stackoverflow.com/a/26026189/3453768
            idx = np.searchsorted(array, values[v], side="left")
            # The nearest entry is either the insertion point or its left neighbor.
            if idx > 0 and (idx == len(array) or math.fabs(values[v] - array[idx-1])
                            < math.fabs(values[v] - array[idx])):
                ind[v] = idx - 1
            else:
                ind[v] = idx
        else:
            # https://stackoverflow.com/a/2566508/3453768
            ind[v] = (np.abs(array - values[v])).argmin()

    return ind.astype(int)
### LIST-HANDLING FUNCTIONS ###
def check_iterable_sizes(iterable_list):
    """Check whether every item of ``iterable_list`` is either a singleton or an
    iterable, with all of the iterables having the same size.

    Iterables of length 1 are treated like singletons and are ignored when
    comparing sizes.

    **Examples:**

    .. testsetup:: *

        from stockpyl.helpers import *

    .. doctest::

        >>> check_iterable_sizes([[5, 3, 1], ('a', 'b', 'c'), 7])
        True
        >>> check_iterable_sizes([[5, 3, 1], ('a', 'b'), 7])
        False

    Parameters
    ----------
    iterable_list : list
        List to check.

    Returns
    -------
    bool
        ``True`` if the iterables in ``iterable_list`` (other than those of
        length 1) all share one size, ``False`` otherwise.
    """
    # Collect the distinct lengths of the "real" iterables in the list.
    distinct_lengths = set()
    for item in iterable_list:
        if is_iterable(item) and len(item) != 1:
            distinct_lengths.add(len(item))

    # Zero or one distinct length means the sizes are consistent.
    return len(distinct_lengths) <= 1
def ensure_list_for_time_periods(x, num_periods, var_name=None):
    """Ensure that ``x`` is a list suitable for time-period indexing; if not, create
    such a list and return it.

    "Suitable for time-period indexing" means that it has length ``num_periods`` + 1,
    and the 0th element is ignored.

        * If ``x`` is a singleton, return a list consisting of ``num_periods`` copies
          of ``x`` plus a ``0`` in the 0th element.
        * If ``x`` has length ``num_periods`` + 1, return ``x`` unchanged.
        * If ``x`` is a list (or other sized iterable, e.g., a tuple) of length
          ``num_periods``, return a new list consisting of ``0`` followed by the
          elements of ``x``.
        * If ``x`` is a numpy array, convert to list first, then follow above rules.
        * Otherwise, raise a ``ValueError``.

    **Examples:**

    .. testsetup:: *

        from stockpyl.helpers import *

    .. doctest::

        >>> ensure_list_for_time_periods(5, 3)
        [0, 5, 5, 5]
        >>> ensure_list_for_time_periods([0, 5, 2, 1], 4)
        [0, 0, 5, 2, 1]
        >>> ensure_list_for_time_periods([5, 2, 1, 3, 2], 4)
        [5, 2, 1, 3, 2]
        >>> ensure_list_for_time_periods([0, 5, 2, 1, 5], 3)
        Traceback (most recent call last):
            ...
        ValueError: x must be a singleton or a list of length num_periods or num_periods+1

    Parameters
    ----------
    x : float or list
        Object to time-period-ify.
    num_periods : int
        Number of time periods.
    var_name : str, optional
        Variable name to use in generating error messages, if desired.
        Useful for offloading the type-checking to this function rather than
        doing it in the calling function.

    Returns
    -------
    x_new : list
        Time-period-ified list.

    Raises
    ------
    ValueError
        If ``x`` is not a singleton or a list of length ``num_periods`` or ``num_periods`` + 1.
    """
    # If numpy array, convert to list.
    if isinstance(x, np.ndarray):
        x = x.tolist()

    # Singleton: replicate it for every period, with a dummy 0th element.
    if not is_iterable(x):
        return [0] + [x] * num_periods

    if len(x) == num_periods + 1:
        # Already has the dummy 0th element.
        return x
    if len(x) == num_periods:
        # list(x) so that tuples and other non-list iterables can be
        # concatenated instead of raising TypeError.
        return [0] + list(x)

    vname = 'x' if var_name is None else var_name
    raise ValueError('{:s} must be a singleton or a list of length num_periods or num_periods+1'.format(vname))
def ensure_list_for_nodes(x, num_nodes, default=None):
    """Ensure that ``x`` is a list suitable for node indexing; if not, create
    such a list and return it.

    "Suitable for node indexing" means that it has length ``num_nodes``.

        * If ``x`` is a singleton, return a list consisting of ``num_nodes`` copies of ``x``.
        * If ``x`` is an iterable with ``num_nodes`` elements, return a list of
          those elements.
        * If ``x`` is ``None`` and ``default`` is provided, return a list consisting of
          ``num_nodes`` copies of ``default``.
        * If ``x`` is ``None`` and ``default`` is not provided, return a list consisting of
          ``num_nodes`` copies of ``None``.
        * Otherwise, raise a ``ValueError``.

    **Examples:**

    .. testsetup:: *

        from stockpyl.helpers import *

    .. doctest::

        >>> ensure_list_for_nodes(5, 3)
        [5, 5, 5]
        >>> ensure_list_for_nodes([0, 5, 2, 1], 4)
        [0, 5, 2, 1]
        >>> ensure_list_for_nodes([0, 5, 2, 1], 3)
        Traceback (most recent call last):
            ...
        ValueError: x must be a singleton, a list of length num_nodes, or None

    Parameters
    ----------
    x : float or list
        Object to node-ify.
    num_nodes : int
        Number of nodes.
    default : float, optional
        Value to use if ``x`` is ``None``.

    Returns
    -------
    x_new : list
        Node-ified list.

    Raises
    ------
    ValueError
        If ``x`` is not a singleton, a list of length ``num_nodes``, or ``None``.
    """
    if x is None:
        return [default] * num_nodes

    # Singleton: replicate it for every node.
    if not is_iterable(x):
        return [x] * num_nodes

    # Materialize exactly once, so that one-shot iterators (e.g., generators)
    # are not exhausted by the length check before being returned.
    items = list(x)
    if len(items) != num_nodes:
        raise ValueError('x must be a singleton, a list of length num_nodes, or None')
    return items
def build_node_data_dict(attribute_dict, node_order_in_lists, default_values=None):
    """Build a dict of dicts containing data for all nodes and for one or more attributes.

    The dict returned, ``data_dict``, is keyed by the node indices.
    ``data_dict[n]`` is a dict of data for the node with index ``n``, and
    ``data_dict[n][a]`` is the value of attribute ``a``, where ``a`` is a key
    in ``attribute_dict``.

    :func:`~build_node_data_dict` is similar to calling :func:`~ensure_dict_for_nodes` for multiple
    attributes simultaneously.

    For each attribute ``a`` in ``attribute_dict`` keys, ``attribute_dict[a]``
    may take one of several forms. For a given node index ``n`` and attribute ``a``:

        * If ``attribute_dict[a]`` is a dict, then ``data_dict[n][a]`` is set to
          ``attribute_dict[a][n]``. If ``n`` is not a key in ``attribute_dict[a]``
          then ``data_dict[n][a]`` is set to ``default_values[a]`` if it exists
          and to ``None`` otherwise.
        * If ``attribute_dict[a]`` is a singleton, then ``data_dict[n][a]`` is set to
          ``attribute_dict[a]``.
        * If ``attribute_dict[a]`` is a list with the same length as ``node_order_in_lists``,
          then the values in the list are assumed to correspond to the node indices in the
          order they are specified in ``node_order_in_lists``. That is, ``attribute_dict[a][k]``
          is placed in ``data_dict[node_order_in_lists[k]][a]``, for ``k`` in
          ``range(len(node_order_in_lists))``.
        * If ``attribute_dict[a]`` is ``None`` and ``default_values[a]`` is provided,
          ``data_dict[n][a]`` is set to ``default_values[a]``.
          (This is useful for setting default values for attributes that are passed without knowing
          whether data is provided for them.)
        * If ``attribute_dict[a]`` is ``None`` and ``default_values[a]`` is not provided,
          ``data_dict[n][a]`` is set to ``None``.
        * If ``attribute_dict[a]`` is a list that does *not* have the same length as
          ``node_order_in_lists``, a ``ValueError`` is raised.

    (Exception: The ``demand_list`` and ``probabilities`` attributes of |class_demand_source| are lists,
    and are treated as singletons for the purposes of the rules above, unless they contain other lists, in which
    case they are treated as normal. For example, ``attribute_dict['demand_list'] = [0, 1, 2, 3]`` will set the
    ``demand_list`` to ``[0, 1, 2, 3]`` for all nodes. But ``demand_list=[[0, 1, 2, 3], [0, 1, 2, 3], None, None]``
    will set the ``demand_list`` to ``[0, 1, 2, 3]`` for nodes 0 and 1 but to ``None`` for nodes 2 and 3.)

    Parameters
    ----------
    attribute_dict : dict
        Dict whose keys are strings (representing node attributes) and whose values
        are dicts, lists, and/or singletons specifying the values of the attributes for the nodes.
    node_order_in_lists : list
        List of node indices in the order in which the nodes are listed in any
        attributes that are lists. (``node_order_in_lists[k]`` is the index of the ``k`` th node.)
    default_values : dict, optional
        Dict whose keys are strings (in ``attribute_dict.keys()`` ) and whose values
        are the default values to use if the attribute is not provided.
        If omitted, no defaults are used.

    Returns
    -------
    data_dict : dict
        The data dict.

    Raises
    ------
    ValueError
        If ``attribute_dict[a]`` is a list whose length does not equal that of ``node_order_in_lists``.


    **Example:**

    .. testsetup:: *

        from stockpyl.helpers import build_node_data_dict

    .. doctest::

        >>> attribute_dict = {}
        >>> attribute_dict['local_holding_cost'] = 1
        >>> attribute_dict['stockout_cost'] = [10, 8, 0]
        >>> attribute_dict['demand_mean'] = {1: 0, 3: 50}
        >>> attribute_dict['lead_time'] = None
        >>> attribute_dict['processing_time'] = None
        >>> node_order_in_lists = [3, 2, 1]
        >>> default_values = {'lead_time': 0, 'demand_mean': 99}
        >>> data_dict = build_node_data_dict(attribute_dict, node_order_in_lists, default_values)
        >>> data_dict[1]
        {'local_holding_cost': 1, 'stockout_cost': 0, 'demand_mean': 0, 'lead_time': 0, 'processing_time': None}
        >>> data_dict[2]
        {'local_holding_cost': 1, 'stockout_cost': 8, 'demand_mean': 99, 'lead_time': 0, 'processing_time': None}
        >>> data_dict[3]
        {'local_holding_cost': 1, 'stockout_cost': 10, 'demand_mean': 50, 'lead_time': 0, 'processing_time': None}

    """
    # Avoid a shared mutable default argument.
    if default_values is None:
        default_values = {}

    # Initialize data_dict.
    data_dict = {n: {} for n in node_order_in_lists}

    # Loop through attributes in attribute_dict.
    for a, attr_value in attribute_dict.items():
        # Fallback for this attribute: its default if provided, else None.
        fallback = default_values.get(a)

        if attr_value is None:
            # No data provided: use the fallback for every node.
            for n in node_order_in_lists:
                data_dict[n][a] = fallback
        elif isinstance(attr_value, dict):
            # Per-node data: use it where present, the fallback elsewhere.
            for n in node_order_in_lists:
                data_dict[n][a] = attr_value[n] if n in attr_value else fallback
        # If a is 'demand_list' or 'probabilities', treat the value as a per-node list
        # only if any of its elements is itself iterable, and as a singleton otherwise.
        elif is_iterable(attr_value) and \
                (a not in ('demand_list', 'probabilities') or any(is_iterable(e) for e in attr_value)):
            # Per-node list: materialize once, check its length, then pair the
            # values with the nodes in order.
            node_values = list(attr_value)
            if len(node_values) != len(node_order_in_lists):
                raise ValueError('attribute_dict[a] must be a singleton, dict, list with the same length as node_indices, or None for all a in attribute_names')
            for n, node_value in zip(node_order_in_lists, node_values):
                data_dict[n][a] = node_value
        else:
            # Singleton (or a list-type attribute): same value for every node.
            for n in node_order_in_lists:
                data_dict[n][a] = attr_value

    return data_dict
def ensure_dict_for_nodes(x, node_indices, default=None):
    """Ensure that ``x`` is a dict with node indices as ``keys``; if not, create
    such a dict and return it.

        * If ``x`` is a dict, return ``x``.
        * If ``x`` is a singleton, return a dict with keys equal to ``node_indices``
          and values all equal to ``x``.
        * If ``x`` is a list with the same length as ``node_indices``, return a dict
          with keys equal to ``node_indices`` and values equal to ``x``.
        * If ``x`` is ``None`` and ``default`` is provided, return a dict with keys
          equal to ``node_indices`` and values all equal to ``default``.
        * If ``x`` is ``None`` and ``default`` is not provided, return a dict with
          keys equal to ``node_indices`` and values all equal to ``None``.
        * Otherwise, raise a ``ValueError``.

    Parameters
    ----------
    x : dict, float, or list
        Object to node-ify.
    node_indices : list
        List of node indices. (``node_indices[k]`` is the index of the ``k`` th node.)
    default : float, optional
        Value to use if ``x`` is ``None``.

    Returns
    -------
    x_new : dict
        Node-ified dict.

    Raises
    ------
    ValueError
        If ``x`` is not a dict, a singleton, a list with the same length as ``node_indices``, or ``None``.


    **Examples:**

    .. testsetup:: *

        from stockpyl.helpers import *

    .. doctest::

        >>> ensure_dict_for_nodes(5, [0, 1, 2])
        {0: 5, 1: 5, 2: 5}
        >>> ensure_dict_for_nodes([0, 5, 2], [0, 1, 2])
        {0: 0, 1: 5, 2: 2}
        >>> ensure_dict_for_nodes([0, 5, 2, 1], [0, 1, 2])
        Traceback (most recent call last):
            ...
        ValueError: x must be a singleton, dict, list with the same length as node_indices, or None

    """
    # Already a dict (including dict subclasses): nothing to do.
    if isinstance(x, dict):
        return x

    # No data: fill every node with the default.
    if x is None:
        return {n_ind: default for n_ind in node_indices}

    # Singleton: same value for every node.
    if not is_iterable(x):
        return {n_ind: x for n_ind in node_indices}

    # Iterable: materialize once so that it need not be subscriptable, then
    # pair its elements with the node indices in order.
    node_values = list(x)
    if len(node_values) != len(node_indices):
        raise ValueError('x must be a singleton, dict, list with the same length as node_indices, or None')
    return dict(zip(node_indices, node_values))
def sort_dict_by_keys(d, ascending=True, return_values=True):
    """Return the values (or the keys) of a dict, ordered by key.

    Keys must all be comparable to one another, i.e., all numbers or all strings,
    with the exception that a ``None`` key is allowed. ``None`` is placed before
    every other key when sorting in ascending order, and after every other key
    when sorting in descending order.

    **Example:**

    .. testsetup:: *

        from stockpyl.helpers import *

    .. doctest::

        >>> the_dict = {'a': 5, None: 14, 'b': 7}
        >>> sort_dict_by_keys(the_dict)
        [14, 5, 7]
        >>> sort_dict_by_keys(the_dict, return_values=False)
        [None, 'a', 'b']

    Parameters
    ----------
    d : dict
        The dict to sort.
    ascending : bool, optional
        Sort order.
    return_values : bool, optional
        Set to ``True`` to return a list of the dict's values, ``False`` to
        return its keys.

    Returns
    -------
    return_list : list
        List of values or keys of ``d``, sorted in order of keys of ``d``.
    """
    # Sort all items except the one keyed by None, which cannot be compared
    # with the other keys.
    ordered_items = sorted(
        ((key, value) for key, value in d.items() if key is not None),
        reverse=not ascending
    )

    # Pick the requested component (0 = key, 1 = value) of each item.
    component = 1 if return_values else 0
    return_list = [item[component] for item in ordered_items]

    # Put the None entry back at the "smallest" end of the ordering.
    if None in d:
        none_entry = d[None] if return_values else None
        if ascending:
            return_list.insert(0, none_entry)
        else:
            return_list.append(none_entry)

    return return_list
def sort_nested_dict_by_keys(d, ascending=True, return_values=True):
    """Sort nested dict by its first two levels of keys and return sorted
    list of values or keys, depending on the value of ``return_values``.

    If ``return_values`` is ``False``, the list returned
    contains tuples ``(key1, key2)``, where ``key1`` is the key from the
    outer dict and ``key2`` is the key from the inner dict.

    Keys must all be comparable to one another, i.e., all numbers or all strings,
    with the exception that ``None`` keys are allowed (at either level).
    ``None`` is placed before any other key when sorting in ascending order,
    and after any other key when sorting in descending order.

    **Example:**

    .. testsetup:: *

        from stockpyl.helpers import *

    .. doctest::

        >>> the_dict = {'b': {None: 1, 'x': 2}, None: {'y': 3}, 'a': {'z': 4}}
        >>> sort_nested_dict_by_keys(the_dict)
        [3, 4, 1, 2]
        >>> sort_nested_dict_by_keys(the_dict, return_values=False)
        [(None, 'y'), ('a', 'z'), ('b', None), ('b', 'x')]

    Parameters
    ----------
    d : dict
        The nested dict to sort. Every value of ``d`` must itself be a dict.
    ascending : bool, optional
        Sort order.
    return_values : bool, optional
        Set to ``True`` to return a list of the innermost values, ``False`` to
        return the ``(key1, key2)`` tuples.

    Returns
    -------
    return_list : list
        List of values or key tuples of ``d``, sorted in order of ``(key1, key2)``.
    """
    def _none_first(key_pair):
        # Rank each key by (is-not-None, key). The leading boolean puts None
        # ahead of everything else and guarantees that None is never compared
        # with a real key, so this works for numeric and string keys alike.
        return tuple((key is not None, key) for key in key_pair)

    # Flatten the first two levels into a single dict keyed by (key1, key2).
    flattened_dict = {(key1, key2): value
                      for key1, inner in d.items()
                      for key2, value in inner.items()}

    ordered_keys = sorted(flattened_dict, key=_none_first, reverse=not ascending)

    if return_values:
        return [flattened_dict[key_pair] for key_pair in ordered_keys]
    return ordered_keys
def change_dict_key(dict_to_change, old_key, new_key):
    """Rename a key of ``dict_to_change`` from ``old_key`` to ``new_key``, in place.

    The entry is removed and re-inserted, so the renamed key/value pair ends up
    at the end of the dict. Nothing is returned.

    Parameters
    ----------
    dict_to_change : dict
        The dict to modify.
    old_key :
        The key to be changed.
    new_key :
        The key to change to.

    Raises
    ------
    KeyError
        If ``old_key`` is not a key of ``dict_to_change``.
    """
    # See https://stackoverflow.com/a/4406521/3453768.
    value = dict_to_change.pop(old_key)
    dict_to_change[new_key] = value
def replace_dict_numeric_string_keys(dict_to_change):
    """Replace any key in ``dict_to_change`` that is a string representing a number
    with the number itself. Return a new dict.

    A numeric string whose value is integral (e.g., ``'3'``, ``'3.0'``, ``'1e3'``)
    is replaced by an ``int``; any other numeric string (e.g., ``'2.5'``) is
    replaced by a ``float``. Non-numeric keys are left unchanged.

    Works recursively to replace numeric-string keys in any values in
    ``dict_to_change`` that are themselves dicts, etc.

    Parameters
    ----------
    dict_to_change : dict
        The dict.

    Returns
    -------
    dict
        The new dict.
    """
    new_dict = {}
    for old_key, v in dict_to_change.items():
        # If v is a dict, call this function recursively on it.
        if is_dict(v):
            v = replace_dict_numeric_string_keys(v)

        # Determine new key, if any.
        if is_numeric_string(old_key):
            new_key = float(old_key)
            if is_integer(new_key):
                # Convert the parsed float, not the string: int('1.0') and
                # int('1e3') raise ValueError even though the value is integral.
                new_key = int(new_key)
        else:
            new_key = old_key

        # Add item to dict.
        new_dict[new_key] = v

    return new_dict
def replace_dict_null_keys(dict_to_change):
    """Build a copy of ``dict_to_change`` in which every key equal to the string
    ``'null'`` is replaced by ``None``.

    Values that are themselves dicts are processed recursively, so ``'null'``
    keys are replaced at every nesting level. The input dict is not modified.

    Parameters
    ----------
    dict_to_change : dict
        The dict.

    Returns
    -------
    dict
        The new dict.
    """
    converted = {}
    for key, value in dict_to_change.items():
        # Recurse into nested dicts before storing them.
        if is_dict(value):
            value = replace_dict_null_keys(value)
        converted[None if key == 'null' else key] = value
    return converted
def compare_unhashable_lists(list1, list2):
    """Tell whether two lists contain the same elements with the same
    multiplicities, regardless of order.

    Elements are matched using ``==`` (via ``list.remove``), so they need not be
    hashable.

    .. note:: Only use this function for lists of unhashable objects (such as |class_node|
        and |class_product|). For hashable objects, ``collections.Counter`` is faster, e.g.,
        ``Counter(list1) == Counter(list2)``.

    Parameters
    ----------
    list1 : list
        The first list to compare.
    list2 : list
        The second list to compare.

    Returns
    -------
    bool
        ``True`` if the two lists have the same elements, with the same counts, not
        necessarily in the same order, ``False`` otherwise.
    """
    # https://stackoverflow.com/a/7829388/3453768
    if len(list1) != len(list2):
        return False

    # Cross off each element of list2 from a working copy of list1.
    unmatched = list(list1)
    for elem in list2:
        try:
            unmatched.remove(elem)
        except ValueError:
            # elem has no remaining counterpart in list1.
            return False

    return not unmatched
### STATS FUNCTIONS ###
def convolve_many(arrays):
    """Compute the convolution of a list of 1-dimensional float arrays using
    fast Fourier transforms (FFTs).

    The arrays need not have the same length, but each array should
    have length at least 1.

    If the arrays represent pmfs of discrete random variables
    :math:`X_1,\\ldots,X_n`, then the output represents the pmf of
    :math:`X_1+\\cdots+X_n`. Assuming the possible values of all of the random
    variables are equally spaced with spacing :math:`s`, the possible values of
    :math:`X_1+\\cdots+X_n` corresponding to the output are
    :math:`\\sum_i\\min X_i,\\ldots,\\sum_i \\max X_i`, with spacing :math:`s`.

    Code is adapted from https://stackoverflow.com/a/29236193/3453768.

    Parameters
    ----------
    arrays : list of 1-dimensional float arrays
        The arrays to convolve.

    Returns
    -------
    convolution : array
        Array of elements in the convolution. Its length is
        ``1 + sum(len(a) - 1 for a in arrays)``.

    Raises
    ------
    ValueError
        If the inverse FFT produces an entry below ``-1.0e-10``.


    **Example:**
    Let :math:`X_1 = \\{0, 1, 2\\}` with probabilities :math:`[0.6, 0.3, 0.1]`,
    :math:`X_2 = \\{0, 1, 2\\}` with probabilities :math:`[0.5, 0.4, 0.1]`,
    :math:`X_3 = \\{0, 1\\}` with probabilities :math:`[0.3, 0.7]`, and
    :math:`X_4 = 0` with probability :math:`1`.

    .. testsetup:: *

        from stockpyl.helpers import *

    .. doctest::

        >>> convolve_many([[0.6, 0.3, 0.1], [0.5, 0.4, 0.1], [0.3, 0.7], [1.0]])
        array([0.09 , 0.327, 0.342, 0.182, 0.052, 0.007])

    In other words, :math:`X_1+\\cdots+X_4 = \\{0, 1, \\ldots, 5\\}` with
    probabilities :math:`[0.09 , 0.327, 0.342, 0.182, 0.052, 0.007]`.

    """
    ROUNDING_TOL = 1.0e-10

    output_length = 1 + sum((len(a) - 1) for a in arrays)

    # Zero-pad every input array to the output length, one array per row.
    padded = np.zeros((len(arrays), output_length))
    for row, a in zip(padded, arrays):
        row[:len(a)] = a

    # Convolution theorem: transform each row, multiply the transforms
    # elementwise across rows, and transform back. The inputs are real, so the
    # imaginary part of the result can be ignored.
    spectrum = np.fft.fft(padded).prod(axis=0)
    convolution = np.fft.ifft(spectrum).real

    # The inverse FFT can produce slightly negative values due to rounding
    # error. Clamp those to 0, but treat anything clearly negative as an error.
    if (convolution < -ROUNDING_TOL).any():
        raise ValueError("np.fft.ifft returned negative results in convolve_many()")
    convolution[(convolution >= -ROUNDING_TOL) & (convolution < 0)] = 0

    return convolution
def irwin_hall_cdf(x, n):
    """Evaluate the cdf of the Irwin-Hall distribution, i.e., the distribution of
    the sum of ``n`` independent :math:`U[0,1]` random variables.

    The value is computed from the closed-form expression

    .. math::

        F(x) = \\frac{1}{n!} \\sum_{k=0}^{\\lfloor x \\rfloor} (-1)^k \\binom{n}{k} (x-k)^n.

    See https://en.wikipedia.org/wiki/Irwin%E2%80%93Hall_distribution.

    Parameters
    ----------
    x : float
        Argument of cdf function.
    n : int
        Number of :math:`U[0,1]` random variables in the sum.

    Returns
    -------
    F : float
        The cdf of ``x``.
    """
    num_terms = int(np.floor(x)) + 1
    alternating_sum = sum(
        ((-1) ** k) * comb(n, k) * (x - k) ** n for k in range(num_terms)
    )
    return alternating_sum / factorial(n)
def sum_of_discretes_distribution(n, lo, hi, p):
    """Return distribution of convolution of ``n`` identical discrete random variables as an
    ``rv_discrete`` object.

    The random variables must have support ``lo``, ``lo`` + 1, ..., ``hi``.
    (The convolution will have support ``n * lo``, ``n * lo`` + 1, ..., ``n * hi``).

    Parameters
    ----------
    n : int
        Number of discrete random variables in the sum.
    lo : int
        Smallest value of the support of the random variable.
    hi : int
        Largest value of the support of the random variable.
    p : list
        Probabilities of each of the values ``lo``, ``lo`` + 1, ..., ``hi``.

    Returns
    -------
    distribution : rv_discrete
        The ``rv_discrete`` object.

    Raises
    ------
    ValueError
        If ``n``, ``lo``, or ``hi`` are not integers.
    ValueError
        If ``p`` does not have length ``hi - lo + 1``.
    """
    # Check whether n, lo, and hi are integers.
    if not is_integer(n):
        raise ValueError("n must be an integer")
    if not is_integer(lo):
        raise ValueError("lo must be an integer")
    if not is_integer(hi):
        raise ValueError("hi must be an integer")

    # Integer-valued floats (e.g., 3.0) pass the checks above; convert them to
    # true ints so that they can be used as a repeat count and as the
    # endpoints of an integer support.
    n, lo, hi = int(n), int(lo), int(hi)

    # Check that p has length hi - lo + 1. (If not, it probably means the calling
    # function did not add 0s for values that are not part of the support.)
    if len(p) != hi - lo + 1:
        raise ValueError(
            "p must have length hi - lo + 1. (If the support of the discrete distribution "
            "is not consecutive integers, you must add zeroes to p for the integers that "
            "are not part of the support)"
        )

    # Support and pmf of the n-fold convolution.
    xk = np.arange(n * lo, n * hi + 1)
    pk = convolve_many([p] * n)
    distribution = stats.rv_discrete(name='sum_of_discretes', values=(xk, pk))

    return distribution
def round_dict_values(the_dict, round_type=None):
    """Build a new dict with the same keys as ``the_dict`` and rounded values.

    Parameters
    ----------
    the_dict : dict
        The dict whose values are to be rounded. It is not modified.
    round_type : string, optional
        Set to 'up' to round all values up to next larger integer, 'down' to
        round down, 'nearest' to round to nearest integer (using the built-in
        ``round``), or ``None`` (or any other value) to not round at all.

    Returns
    -------
    dict
        New dict containing the (possibly rounded) values.
    """
    # Choose the rounding function once, up front.
    if round_type == 'up':
        rounder = math.ceil
    elif round_type == 'down':
        rounder = math.floor
    elif round_type == 'nearest':
        rounder = round
    else:
        # No rounding requested: return a plain copy.
        return {key: value for key, value in the_dict.items()}

    return {key: rounder(value) for key, value in the_dict.items()}