adding graph timeseries#807
Conversation
dario-coscia
left a comment
There was a problem hiding this comment.
Minor comment on an extra feature, but overall looks great. Be aware that tests are failing
| residuals = [] | ||
|
|
||
| # Iterate over the time steps | ||
| for step in range(1, batch["input"].shape[2]): |
There was a problem hiding this comment.
@GiovanniCanali what do you think to add the pushforward trick here as well? Is it a condition thing or solver thing? Having it is very easy (we just need the no grad option)
There was a problem hiding this comment.
This is definitely a condition-level concern, since gradient computation must be enabled or disabled when evaluating the residual between the model prediction and the target. As you noted, the implementation should be straightforward. Since the forward method is inherited from TimeSeriesCondition, we should consider implementing it directly there.
GiovanniCanali
left a comment
There was a problem hiding this comment.
I know this is still a work in progress, but I have added a few comments below.
@ndem0, please remember to add the appropriate mapping to the Condition factory so that the software automatically dispatches graph time-series conditions and standard time-series conditions to the correct classes.
|
|
||
| return _DataManager(input=graph) | ||
|
|
||
| def evaluate(self, batch, solver): |
There was a problem hiding this comment.
This method appears to be identical to the one defined in TimeSeriesCondition and should therefore be inherited without modification.
| return torch.stack(residuals).as_subclass(torch.Tensor) | ||
|
|
||
| @property | ||
| def input(self): |
There was a problem hiding this comment.
This method appears to be identical to the one defined in TimeSeriesCondition and should therefore be inherited without modification.
| residuals = [] | ||
|
|
||
| # Iterate over the time steps | ||
| for step in range(1, batch["input"].shape[2]): |
There was a problem hiding this comment.
This is definitely a condition-level concern, since gradient computation must be enabled or disabled when evaluating the residual between the model prediction and the target. As you noted, the implementation should be straightforward. Since the forward method is inherited from TimeSeriesCondition, we should consider implementing it directly there.
Description
Added the graph time series condition.
Checklist