Merge pull request #30 from jrzaurin/fix_image_format
Fix image format
jrzaurin authored Dec 4, 2020
2 parents 9f61051 + 3913afe commit 2fe4b49
Showing 15 changed files with 704 additions and 237 deletions.
54 changes: 40 additions & 14 deletions README.md
@@ -1,6 +1,6 @@

<p align="center">
<img width="450" src="docs/figures/widedeep_logo.png">
<img width="300" src="docs/figures/widedeep_logo.png">
</p>

[![Build Status](https://travis-ci.org/jrzaurin/pytorch-widedeep.svg?branch=master)](https://travis-ci.org/jrzaurin/pytorch-widedeep)
@@ -9,11 +9,7 @@
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)

Platform | Version Support
---------|:---------------
OSX | [![Python 3.6 3.7](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://www.python.org/)
Linux | [![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
[![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)

# pytorch-widedeep

@@ -88,15 +84,23 @@ as:
<img width="300" src="docs/figures/architecture_2_math.png">
</p>

When using `pytorch-widedeep`, the assumption is that the so called `Wide` and
`deep dense` (this can be either `DeepDense` or `DeepDenseResnet`. See the
documentation and examples folder for more details) components in the figures
are **always** present, while `DeepText text` and `DeepImage` are optional.
Note that each individual component, `wide`, `deepdense` (either `DeepDense`
or `DeepDenseResnet`), `deeptext` and `deepimage`, can be used independently
and in isolation. For example, one could use only `wide`, which is simply a
linear model.

On the other hand, while I recommend using the `Wide` and `DeepDense` (or
`DeepDenseResnet`) classes in `pytorch-widedeep` to build the `wide` and
`deepdense` component, it is very likely that users will want to use their own
models in the case of the `deeptext` and `deepimage` components. That is
perfectly possible as long as the custom models have an attribute called
`output_dim` with the size of the last layer of activations, so that
`WideDeep` can be constructed.

`pytorch-widedeep` includes standard text (stack of LSTMs) and image
(pre-trained ResNets or stack of CNNs) models. However, the user can use any
custom model as long as it has an attribute called `output_dim` with the size
of the last layer of activations, so that `WideDeep` can be constructed. See
the examples folder or the docs for more information.
(pre-trained ResNets or stack of CNNs) models.

See the examples folder or the docs for more information.
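
The `output_dim` requirement described above can be illustrated with a minimal, framework-free sketch. Everything in it is hypothetical: `MeanPoolTextEncoder` and `check_output_dim` are not part of `pytorch-widedeep`, and a real custom component would normally be a `torch.nn.Module`. The sketch only shows the contract, namely that the custom model exposes the size of its last layer of activations as an attribute.

```python
class MeanPoolTextEncoder:
    """Hypothetical custom text component (plain-Python stand-in for a model)."""

    def __init__(self, vocab_size: int, embed_dim: int):
        # One embedding row per token id; row 0 is kept for padding.
        self.embeddings = [[0.0] * embed_dim for _ in range(vocab_size + 1)]
        # The attribute the README asks for: size of the last layer of activations.
        self.output_dim = embed_dim

    def __call__(self, token_ids):
        # Average the embedding rows of the tokens, column by column.
        rows = [self.embeddings[t] for t in token_ids]
        return [sum(col) / len(rows) for col in zip(*rows)]


def check_output_dim(component) -> int:
    """Hypothetical helper: return `output_dim` or fail with a clear message."""
    dim = getattr(component, "output_dim", None)
    if not isinstance(dim, int) or dim <= 0:
        raise AttributeError(
            f"{type(component).__name__} must define a positive integer `output_dim`"
        )
    return dim


encoder = MeanPoolTextEncoder(vocab_size=100, embed_dim=8)
features = encoder([3, 17, 42])
print(check_output_dim(encoder), len(features))
```

The point of the check is that the length of the features the component produces and the advertised `output_dim` agree, which is what lets a downstream layer be sized without inspecting the custom model.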


### Installation
@@ -124,6 +128,28 @@ cd pytorch-widedeep
pip install -e .
```

**Important note for Mac users**: at the time of writing (Dec-2020) the latest
`torch` release is `1.7`. This release has some
[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
when running on Mac and the data-loaders will not run in parallel. In
addition, since `python 3.8`, [the default `multiprocessing` start method on
macOS changed from `'fork'` to
`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will not
run in parallel. Therefore, for Mac users I recommend using `python 3.6` or
`3.7` and `torch <= 1.6` (with the corresponding, consistent version of
`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
versioning in the `setup.py` file since I expect that all these issues will be
fixed in the future. Therefore, after installing `pytorch-widedeep` via pip or
directly from GitHub, downgrade `torch` and `torchvision` manually:

```bash
pip install pytorch-widedeep
pip install torch==1.6.0 torchvision==0.7.0
```

None of these issues affect Linux users.
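
To see which case applies on a given machine, the start method can be inspected with the standard library alone. This is only a small diagnostic sketch: the function name `describe_start_method` is made up for illustration, and the snippet says nothing about `torch` itself, it just reports what `multiprocessing` will use by default.

```python
import multiprocessing as mp
import platform
import sys


def describe_start_method() -> str:
    """Report the interpreter version, platform and default start method."""
    # get_start_method() returns the name of the method used to start processes.
    method = mp.get_start_method()
    version = f"{sys.version_info.major}.{sys.version_info.minor}"
    return f"python {version} on {platform.system()}: start method = {method!r}"


print(describe_start_method())
print("available:", mp.get_all_start_methods())
```

If the reported method is `'spawn'` on a Mac, the note above about the data-loaders is the relevant one.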

### Quick start

Binary classification with the [adult
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.4.6
0.4.7
Binary file modified docs/figures/widedeep_logo.png
Binary file added docs/figures/widedeep_logo_old.png
8 changes: 5 additions & 3 deletions examples/02_Model_Components.ipynb
@@ -130,7 +130,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"if we simply numerically encode (label encode or `le`) the values, starting from 1 (we will save 0 for padding, i.e. unseen values)"
"if we simply numerically encode (label encode or `le`) the values:"
]
},
{
@@ -146,7 +146,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"now, let's see if the two implementations are equivalent"
"Note that in the functioning implementation of the package we start from 1, saving 0 for padding, i.e. unseen values. \n",
"\n",
"Now, let's see if the two implementations are equivalent"
]
},
{
@@ -261,7 +263,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that even though the input dim is 10, the Embedding layer has 11 weights. This is because we save 0 for padding, which is used for unseen values during the encoding process"
"Note that even though the input dim is 10, the Embedding layer has 11 weights. Again, this is because we save 0 for padding, which is used for unseen values during the encoding process"
]
},
{
106 changes: 86 additions & 20 deletions examples/03_Binary_Classification_with_Defaults.ipynb
@@ -591,16 +591,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:05<00:00, 115.33it/s, loss=0.743, metrics={'acc': 0.6205, 'prec': 0.2817}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 168.06it/s, loss=0.545, metrics={'acc': 0.6452, 'prec': 0.3014}]\n",
"epoch 2: 100%|██████████| 611/611 [00:04<00:00, 122.57it/s, loss=0.486, metrics={'acc': 0.7765, 'prec': 0.5517}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 158.84it/s, loss=0.44, metrics={'acc': 0.783, 'prec': 0.573}] \n",
"epoch 3: 100%|██████████| 611/611 [00:04<00:00, 124.89it/s, loss=0.419, metrics={'acc': 0.8129, 'prec': 0.6753}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 158.10it/s, loss=0.402, metrics={'acc': 0.815, 'prec': 0.6816}] \n",
"epoch 4: 100%|██████████| 611/611 [00:04<00:00, 126.35it/s, loss=0.393, metrics={'acc': 0.8228, 'prec': 0.7047}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 160.72it/s, loss=0.385, metrics={'acc': 0.8233, 'prec': 0.7024}]\n",
"epoch 5: 100%|██████████| 611/611 [00:04<00:00, 124.33it/s, loss=0.38, metrics={'acc': 0.826, 'prec': 0.702}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 163.43it/s, loss=0.376, metrics={'acc': 0.8264, 'prec': 0.7}] \n"
"epoch 1: 100%|██████████| 611/611 [00:06<00:00, 101.71it/s, loss=0.448, metrics={'acc': 0.792, 'prec': 0.5728}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 171.00it/s, loss=0.366, metrics={'acc': 0.7991, 'prec': 0.5907}]\n",
"epoch 2: 100%|██████████| 611/611 [00:06<00:00, 101.69it/s, loss=0.361, metrics={'acc': 0.8324, 'prec': 0.6817}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 169.36it/s, loss=0.357, metrics={'acc': 0.8328, 'prec': 0.6807}]\n",
"epoch 3: 100%|██████████| 611/611 [00:05<00:00, 102.65it/s, loss=0.352, metrics={'acc': 0.8366, 'prec': 0.691}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 171.49it/s, loss=0.352, metrics={'acc': 0.8361, 'prec': 0.6867}]\n",
"epoch 4: 100%|██████████| 611/611 [00:06<00:00, 101.52it/s, loss=0.347, metrics={'acc': 0.8389, 'prec': 0.6956}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 163.49it/s, loss=0.349, metrics={'acc': 0.8383, 'prec': 0.6906}]\n",
"epoch 5: 100%|██████████| 611/611 [00:07<00:00, 84.91it/s, loss=0.343, metrics={'acc': 0.8405, 'prec': 0.6987}] \n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 142.83it/s, loss=0.347, metrics={'acc': 0.8399, 'prec': 0.6946}]\n"
]
}
],
@@ -664,22 +664,88 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:05<00:00, 108.62it/s, loss=0.894, metrics={'acc': 0.5182, 'prec': 0.2037}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 154.44it/s, loss=0.604, metrics={'acc': 0.5542, 'prec': 0.2135}]\n",
"epoch 2: 100%|██████████| 611/611 [00:05<00:00, 106.49it/s, loss=0.51, metrics={'acc': 0.751, 'prec': 0.4614}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.79it/s, loss=0.452, metrics={'acc': 0.7581, 'prec': 0.4898}]\n",
"epoch 3: 100%|██████████| 611/611 [00:05<00:00, 106.66it/s, loss=0.425, metrics={'acc': 0.8031, 'prec': 0.6618}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 160.73it/s, loss=0.405, metrics={'acc': 0.806, 'prec': 0.6686}] \n",
"epoch 4: 100%|██████████| 611/611 [00:05<00:00, 106.58it/s, loss=0.394, metrics={'acc': 0.8185, 'prec': 0.6966}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 155.55it/s, loss=0.385, metrics={'acc': 0.8196, 'prec': 0.6994}]\n",
"epoch 5: 100%|██████████| 611/611 [00:05<00:00, 107.28it/s, loss=0.38, metrics={'acc': 0.8236, 'prec': 0.7004}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 155.37it/s, loss=0.375, metrics={'acc': 0.8244, 'prec': 0.7017}]\n"
"epoch 1: 100%|██████████| 611/611 [00:07<00:00, 77.46it/s, loss=0.387, metrics={'acc': 0.8192, 'prec': 0.6576}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 147.78it/s, loss=0.36, metrics={'acc': 0.8216, 'prec': 0.6617}] \n",
"epoch 2: 100%|██████████| 611/611 [00:08<00:00, 74.99it/s, loss=0.358, metrics={'acc': 0.8313, 'prec': 0.6836}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 158.26it/s, loss=0.355, metrics={'acc': 0.8321, 'prec': 0.6848}]\n",
"epoch 3: 100%|██████████| 611/611 [00:08<00:00, 76.28it/s, loss=0.351, metrics={'acc': 0.8345, 'prec': 0.6889}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 154.84it/s, loss=0.354, metrics={'acc': 0.8347, 'prec': 0.6887}]\n",
"epoch 4: 100%|██████████| 611/611 [00:07<00:00, 76.71it/s, loss=0.346, metrics={'acc': 0.8374, 'prec': 0.6946}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.80it/s, loss=0.353, metrics={'acc': 0.8369, 'prec': 0.6935}]\n",
"epoch 5: 100%|██████████| 611/611 [00:08<00:00, 73.25it/s, loss=0.343, metrics={'acc': 0.8386, 'prec': 0.6966}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.05it/s, loss=0.352, metrics={'acc': 0.8382, 'prec': 0.6961}]\n"
]
}
],
"source": [
"model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=64, val_split=0.2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is also worth mentioning that one could build a model with the individual components independently. For example, a model comprising only the `wide` component would simply be a linear model. This could be achieved with:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"model = WideDeep(wide=wide)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"model.compile(method='binary', metrics=[Accuracy, Precision])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/611 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:03<00:00, 188.59it/s, loss=0.482, metrics={'acc': 0.771, 'prec': 0.5633}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 236.13it/s, loss=0.423, metrics={'acc': 0.7747, 'prec': 0.5819}]\n",
"epoch 2: 100%|██████████| 611/611 [00:03<00:00, 190.62it/s, loss=0.399, metrics={'acc': 0.8131, 'prec': 0.686}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 221.47it/s, loss=0.387, metrics={'acc': 0.8138, 'prec': 0.6879}]\n",
"epoch 3: 100%|██████████| 611/611 [00:03<00:00, 190.28it/s, loss=0.378, metrics={'acc': 0.8267, 'prec': 0.7149}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 241.12it/s, loss=0.374, metrics={'acc': 0.8255, 'prec': 0.7128}]\n",
"epoch 4: 100%|██████████| 611/611 [00:03<00:00, 183.27it/s, loss=0.37, metrics={'acc': 0.8304, 'prec': 0.7073}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 227.46it/s, loss=0.369, metrics={'acc': 0.8294, 'prec': 0.7061}]\n",
"epoch 5: 100%|██████████| 611/611 [00:03<00:00, 184.28it/s, loss=0.366, metrics={'acc': 0.8315, 'prec': 0.7006}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 239.87it/s, loss=0.366, metrics={'acc': 0.8303, 'prec': 0.6999}]\n"
]
}
],
"source": [
"model.fit(X_wide=X_wide, target=target, n_epochs=5, batch_size=64, val_split=0.2)"
]
}
],
"metadata": {
28 changes: 23 additions & 5 deletions pypi_README.md
@@ -4,11 +4,7 @@
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)

Platform | Version Support
---------|:---------------
OSX | [![Python 3.6 3.7](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://www.python.org/)
Linux | [![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
[![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)

# pytorch-widedeep

@@ -57,6 +53,28 @@ cd pytorch-widedeep
pip install -e .
```

**Important note for Mac users**: at the time of writing (Dec-2020) the latest
`torch` release is `1.7`. This release has some
[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
when running on Mac and the data-loaders will not run in parallel. In
addition, since `python 3.8`, [the default `multiprocessing` start method on
macOS changed from `'fork'` to
`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will not
run in parallel. Therefore, for Mac users I recommend using `python 3.6` or
`3.7` and `torch <= 1.6` (with the corresponding, consistent version of
`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
versioning in the `setup.py` file since I expect that all these issues will be
fixed in the future. Therefore, after installing `pytorch-widedeep` via pip or
directly from GitHub, downgrade `torch` and `torchvision` manually:

```bash
pip install pytorch-widedeep
pip install torch==1.6.0 torchvision==0.7.0
```

None of these issues affect Linux users.

### Quick start

Binary classification with the [adult
27 changes: 19 additions & 8 deletions pytorch_widedeep/models/_wd_dataset.py
@@ -27,11 +27,11 @@ class WideDeepDataset(Dataset):

     def __init__(
         self,
-        X_wide: np.ndarray,
-        X_deep: np.ndarray,
-        target: Optional[np.ndarray] = None,
+        X_wide: Optional[np.ndarray] = None,
+        X_deep: Optional[np.ndarray] = None,
         X_text: Optional[np.ndarray] = None,
         X_img: Optional[np.ndarray] = None,
+        target: Optional[np.ndarray] = None,
         transforms: Optional[Any] = None,
     ):

@@ -48,10 +48,12 @@ def __init__(
             self.transforms_names = []
         self.Y = target

-    def __getitem__(self, idx: int):
-        # X_wide and X_deep are assumed to be *always* present
-        X = Bunch(wide=self.X_wide[idx])
-        X.deepdense = self.X_deep[idx]
+    def __getitem__(self, idx: int):  # noqa: C901
+        X = Bunch()
+        if self.X_wide is not None:
+            X.wide = self.X_wide[idx]
+        if self.X_deep is not None:
+            X.deepdense = self.X_deep[idx]
         if self.X_text is not None:
             X.deeptext = self.X_text[idx]
         if self.X_img is not None:
@@ -68,6 +70,8 @@ def __getitem__(self, idx: int):
             # then we need to replicate what Tensor() does -> transpose axis
             # and normalize if necessary
             if not self.transforms or "ToTensor" not in self.transforms_names:
+                if xdi.ndim == 2:
+                    xdi = xdi[:, :, None]
                 xdi = xdi.transpose(2, 0, 1)
                 if "int" in str(xdi.dtype):
                     xdi = (xdi / xdi.max()).astype("float32")
@@ -87,4 +91,11 @@ def __getitem__(self, idx: int):
         return X

     def __len__(self):
-        return len(self.X_deep)
+        if self.X_wide is not None:
+            return len(self.X_wide)
+        if self.X_deep is not None:
+            return len(self.X_deep)
+        if self.X_text is not None:
+            return len(self.X_text)
+        if self.X_img is not None:
+            return len(self.X_img)
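
The idea behind this change, a dataset in which every input is optional, can be sketched without `torch` or `numpy`. The class below is a hypothetical stand-in, not the package's `WideDeepDataset`: it uses plain lists and a dict instead of arrays and `Bunch`. It also differs from the diff in one respect, stated as a deliberate assumption: when no input at all is provided, `__len__` raises instead of falling through and implicitly returning `None`.

```python
class OptionalInputsDataset:
    """Hypothetical stand-in: every input component is optional."""

    def __init__(self, X_wide=None, X_deep=None, X_text=None, X_img=None, target=None):
        # Keep the components under the names used in the README.
        self.inputs = {
            "wide": X_wide,
            "deepdense": X_deep,
            "deeptext": X_text,
            "deepimage": X_img,
        }
        self.Y = target

    def __getitem__(self, idx):
        # Only the components that were actually provided end up in the sample.
        X = {name: data[idx] for name, data in self.inputs.items() if data is not None}
        if self.Y is not None:
            return X, self.Y[idx]
        return X

    def __len__(self):
        # The length comes from the first component that is present.
        for data in self.inputs.values():
            if data is not None:
                return len(data)
        # Assumption of this sketch: fail loudly rather than return None.
        raise ValueError("at least one input component must be provided")


wide_only = OptionalInputsDataset(X_wide=[[1, 2], [3, 4]], target=[0, 1])
print(len(wide_only), wide_only[1])
```

A wide-only dataset like `wide_only` is what makes it possible to fit a model built from the `wide` component alone.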